Commit d4b7e0d (verified) by Tomtom84 · Parent: 94f10a6

Update app.py

Files changed (1): app.py (+9 -9)
app.py CHANGED
@@ -122,15 +122,15 @@ async def tts(ws: WebSocket):
         # Update buffer_pos based on current buffer length before generation
         masker.buffer_pos = len(buf)
 
-        # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
+        # --- Mini‑Generate (Cache Disabled for Debugging) -------------------------------------------
         gen = model.generate(
-            input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),  # Reverted
-            attention_mask  = attn if past is None else None,  # Reverted
-            past_key_values = past,  # Re-enabled cache
+            input_ids       = ids,   # Always use full sequence
+            attention_mask  = attn,  # Always use full attention mask
+            # past_key_values = past,  # Disabled cache
             max_new_tokens  = CHUNK_TOKENS,
             logits_processor=[masker],
             do_sample=True, temperature=0.7, top_p=0.95,
-            use_cache=True,  # Re-enabled cache
+            use_cache=False,  # Disabled cache
             return_dict_in_generate=True,
             return_legacy_cache=True
         )
@@ -142,10 +142,10 @@ async def tts(ws: WebSocket):
             break
         offset_len += len(new)
 
-        # ----- Update past and last_tok (Cache Re-enabled) ---------
-        # ids  = torch.tensor([seq], device=device)  # Removed
-        # attn = torch.ones_like(ids)  # Removed
-        past = gen.past_key_values  # Re-enabled cache access
+        # ----- Update ids and attn with the full sequence (Cache Disabled) ---------
+        ids  = torch.tensor([seq], device=device)  # Re-added
+        attn = torch.ones_like(ids)  # Re-added
+        # past = gen.past_key_values  # Disabled cache access
         last_tok = new[-1]
 
         print("new tokens:", new[:25], flush=True)
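For context, the new code path regenerates from the full token sequence on every chunk instead of reusing the KV cache. Below is a minimal, hedged sketch of that cache-free chunked loop; it uses a plain GPT-2 checkpoint as a stand-in for the app's TTS model, omits the app's custom masker logits processor and the WebSocket streaming, and only borrows the names (ids, attn, seq, CHUNK_TOKENS) from the diff for illustration.

# Hedged sketch only: cache-free chunked generation as in the new code path,
# with GPT-2 standing in for the real model; the app's masker logits
# processor and WebSocket plumbing are intentionally left out.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

CHUNK_TOKENS = 8                                    # mirrors the constant in app.py
seq = tok("Hello, streaming world", return_tensors="pt").input_ids[0].tolist()

for _ in range(3):                                  # a few mini-generate rounds
    ids = torch.tensor([seq], device=device)        # always the full sequence so far
    attn = torch.ones_like(ids)                     # full attention mask
    gen = model.generate(
        input_ids=ids,
        attention_mask=attn,
        max_new_tokens=CHUNK_TOKENS,
        do_sample=True, temperature=0.7, top_p=0.95,
        use_cache=False,                            # no KV cache, as in the commit
        return_dict_in_generate=True,
        pad_token_id=tok.eos_token_id,
    )
    new = gen.sequences[0, len(seq):].tolist()      # tokens produced this round
    seq = gen.sequences[0].tolist()                 # next round re-feeds everything
    print("new tokens:", new[:25], flush=True)

The trade-off: without the cache, each chunk recomputes attention over the entire prefix, so per-chunk latency grows with sequence length. The removed path avoided that by feeding only the last emitted token plus gen.past_key_values back into generate(), at the cost of the cache-handling bugs this commit is debugging.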