Tomtom84 committed on
Commit 0238891 · verified · 1 Parent(s): 2fa4182

Update app.py

Files changed (1)
  1. app.py +15 -13
app.py CHANGED
@@ -3,7 +3,6 @@ import os, json, torch, asyncio
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from huggingface_hub import login
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
-from transformers.generation.utils import Cache
 from snac import SNAC
 
 # 0) Login + Device ---------------------------------------------------
@@ -108,27 +107,30 @@ async def tts(ws: WebSocket):
     buf = []
 
     while True:
-        # --- Mini-Generate -------------------------------------------
+        # --- Mini-Generate (Cache Disabled for Debugging) -------------
         gen = model.generate(
-            input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),
-            attention_mask  = attn if past is None else None,
-            past_key_values = past,
-            max_new_tokens  = CHUNK_TOKENS,
-            logits_processor= [masker],
+            input_ids       = ids,        # Always use full sequence
+            attention_mask  = attn,       # Always use full attention mask
+            # past_key_values = past,     # Disabled cache
+            max_new_tokens  = CHUNK_TOKENS,
+            logits_processor=[masker],
             do_sample=True, temperature=0.7, top_p=0.95,
-            use_cache=True,
-            return_dict_in_generate=True,  # Added return_dict_in_generate
-            return_legacy_cache=True       # Added legacy cache
+            use_cache=False,              # Disabled cache
+            return_dict_in_generate=True,
+            return_legacy_cache=True
         )
 
         # ----- slice out the new tokens -------------------------------
-        new = gen.sequences[0][offset_len:].tolist()  # Access sequences attribute
+        seq = gen.sequences[0].tolist()
+        new = seq[offset_len:]
        if not new:                       # nothing -> done
            break
        offset_len += len(new)
 
-        # ----- continue with the cache (last PKV lives in the model) --
-        past = gen.past_key_values  # Corrected cache access
+        # ----- update ids and attn with the full sequence (cache disabled)
+        ids  = torch.tensor([seq], device=device)
+        attn = torch.ones_like(ids)
+        # past = gen.past_key_values  # Disabled cache access
        last_tok = new[-1]
 
        print("new tokens:", new[:25], flush=True)