Tomtom84 committed
Commit 94f10a6 · verified · 1 parent: 3d65908

Update app.py

Files changed (1): app.py  +15 -12
app.py CHANGED
@@ -116,18 +116,21 @@ async def tts(ws: WebSocket):
         offset_len = ids.size(1)  # how many tokens already exist
         last_tok = None
         buf = []
-        masker.buffer_pos = 0  # Initialize buffer position for masker
+        # masker.buffer_pos = 0  # Removed initialization here
 
         while True:
-            # --- Mini-Generate (Cache Disabled for Debugging) -------------------------------------------
+            # Update buffer_pos based on current buffer length before generation
+            masker.buffer_pos = len(buf)
+
+            # --- Mini-Generate (Cache Re-enabled) -------------------------------------------
             gen = model.generate(
-                input_ids       = ids,   # Always use full sequence
-                attention_mask  = attn,  # Always use full attention mask
-                # past_key_values = past,  # Disabled cache
+                input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),  # Reverted
+                attention_mask  = attn if past is None else None,  # Reverted
+                past_key_values = past,  # Re-enabled cache
                 max_new_tokens  = CHUNK_TOKENS,
                 logits_processor=[masker],
                 do_sample=True, temperature=0.7, top_p=0.95,
-                use_cache=False,  # Disabled cache
+                use_cache=True,  # Re-enabled cache
                 return_dict_in_generate=True,
                 return_legacy_cache=True
             )
@@ -139,10 +142,10 @@ async def tts(ws: WebSocket):
                 break
             offset_len += len(new)
 
-            # ----- Update ids and attn with the full sequence (Cache Disabled) ---------
-            ids  = torch.tensor([seq], device=device)
-            attn = torch.ones_like(ids)
-            # past = gen.past_key_values  # Disabled cache access
+            # ----- Update past and last_tok (Cache Re-enabled) ---------
+            # ids  = torch.tensor([seq], device=device)  # Removed
+            # attn = torch.ones_like(ids)  # Removed
+            past = gen.past_key_values  # Re-enabled cache access
             last_tok = new[-1]
 
             print("new tokens:", new[:25], flush=True)
@@ -155,12 +158,12 @@ async def tts(ws: WebSocket):
                     buf.clear()
                     continue
                 buf.append(t - AUDIO_BASE)
-                masker.buffer_pos += 1  # Increment buffer position
+                # masker.buffer_pos += 1  # Removed increment here
                 if len(buf) == 7:
                     await ws.send_bytes(decode_block(buf))
                     buf.clear()
                     masker.sent_blocks = 1  # EOS allowed from now on
-                    masker.buffer_pos = 0  # Reset buffer position after sending a block
+                    # masker.buffer_pos = 0  # Removed reset here
 
     except (StopIteration, WebSocketDisconnect):
         pass
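
Taken together, the change reverts the debugging state: the loop again carries gen.past_key_values across rounds and, once the cache is warm, feeds model.generate only the single newest token, while masker.buffer_pos is now derived from len(buf) once per round instead of being incremented and reset inside the send loop.

For context, here is a minimal, self-contained sketch of the KV-cache pattern this re-enables, assuming a transformers causal LM. It uses a stand-in gpt2 checkpoint and a manual forward loop rather than the app's model.generate call; CHUNK_TOKENS and the 0.7 temperature mirror the diff, top_p filtering and the logits processor are omitted, and all other names are illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("gpt2")                       # stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

CHUNK_TOKENS = 7
ids = tok("Hello", return_tensors="pt").input_ids.to(device)
past = None

with torch.no_grad():
    for _ in range(3 * CHUNK_TOKENS):                             # three blocks' worth, one token per pass
        # Cold cache: feed the whole prompt. Warm cache: only the newest
        # token; past_key_values supplies the rest of the context.
        inp = ids if past is None else ids[:, -1:]
        out = model(input_ids=inp, past_key_values=past, use_cache=True)
        past = out.past_key_values                                # carry the cache forward
        probs = torch.softmax(out.logits[:, -1] / 0.7, dim=-1)    # temperature 0.7, as in the diff
        next_tok = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_tok], dim=-1)

print(tok.decode(ids[0]))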
 
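The diff treats masker as a stateful logits processor with buffer_pos and sent_blocks attributes, but the class itself is not part of this commit. The skeleton below is hypothetical: the attribute names and the logits_processor=[masker] hookup come from the diff, while the constructor arguments, the token-range constants, and the masking policy are assumptions.

import torch
from transformers import LogitsProcessor

class AudioBlockMasker(LogitsProcessor):
    # Hypothetical stand-in for the diff's `masker`; the real class is not
    # shown in this commit, so everything beyond buffer_pos/sent_blocks is assumed.
    def __init__(self, eos_id, audio_base, audio_vocab):
        self.eos_id = eos_id
        self.audio_base = audio_base    # id of the first audio code token (cf. AUDIO_BASE)
        self.audio_vocab = audio_vocab  # number of audio code ids
        self.buffer_pos = 0             # position inside the current 7-token block
        self.sent_blocks = 0            # the diff sets this to 1 after a block is sent

    def __call__(self, input_ids, scores):
        mask = torch.full_like(scores, float("-inf"))
        # Audio code tokens are always legal.
        mask[:, self.audio_base : self.audio_base + self.audio_vocab] = 0.0
        # Permit EOS only at a block boundary and only once a block went out;
        # buffer_pos reflects the buffer state at the start of the round.
        if self.sent_blocks > 0 and self.buffer_pos == 0:
            mask[:, self.eos_id] = 0.0
        return scores + mask

On this reading, the bookkeeping change is coherent: setting masker.buffer_pos = len(buf) once before each generate call synchronizes the processor with the partially filled buffer, replacing the per-token increment and the post-send reset that the old code maintained by hand.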