Tomtom84 committed on
Commit
5d73119
·
verified ·
1 Parent(s): 5031731

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -11
app.py CHANGED
@@ -107,37 +107,44 @@ def decode_block(block7: list[int]) -> bytes:
107
  async def tts(ws: WebSocket):
108
  await ws.accept()
109
  try:
110
- req = json.loads(await ws.receive_text())
111
  text = req.get("text", "")
112
  voice = req.get("voice", "Jakob")
113
 
114
- ids, attn = build_inputs(text, voice)
115
  past = None
 
116
  buf = []
117
 
118
  while True:
119
  out = model.generate(
120
- input_ids = ids if past is None else None,
121
  attention_mask = attn if past is None else None,
122
  past_key_values = past,
123
  max_new_tokens = CHUNK_TOKENS,
124
- logits_processor= [masker], # ► dynamischer Masker
125
  do_sample=True, temperature=0.7, top_p=0.95,
126
  return_dict_in_generate=True,
127
  use_cache=True,
 
128
  )
129
 
130
- # Cache & neue Tokens extrahieren --------------------------------
131
  pkv = out.past_key_values
132
- if isinstance(pkv, Cache): # HF >= 4.47
133
  pkv = pkv.to_legacy()
134
  past = pkv
135
 
136
  new = out.sequences[0, -out.num_generated_tokens :].tolist()
137
- print("new tokens:", new[:20]) # Debug‑Ausgabe
138
 
139
- # ----------------------------------------------------------------
 
 
 
140
  for t in new:
 
 
141
  if t == EOS_TOKEN:
142
  raise StopIteration
143
 
@@ -149,9 +156,9 @@ async def tts(ws: WebSocket):
149
  if len(buf) == 7:
150
  await ws.send_bytes(decode_block(buf))
151
  buf.clear()
152
- masker.blocks_done += 1 #  jetzt darf ggf. EOS
153
 
154
- # nächsten generate‑Step nur noch mit Cache, keine neuen ids
155
  ids, attn = None, None
156
 
157
  except (StopIteration, WebSocketDisconnect):
@@ -165,7 +172,7 @@ async def tts(ws: WebSocket):
165
  try:
166
  await ws.close()
167
  except RuntimeError:
168
- pass # Close‑Frame war schon raus
169
 
170
  # ── 6. Lokaler Start (uvicorn) ───────────────────────────────────────
171
  if __name__ == "__main__":
 
107
  async def tts(ws: WebSocket):
108
  await ws.accept()
109
  try:
110
+ req = json.loads(await ws.receive_text())
111
  text = req.get("text", "")
112
  voice = req.get("voice", "Jakob")
113
 
114
+ ids, attn = build_inputs(text, voice) # vollständiger Prompt
115
  past = None
116
+ last_tok = None # <- NEU
117
  buf = []
118
 
119
  while True:
120
  out = model.generate(
121
+ input_ids = ids if past is None else torch.tensor([[last_tok]], device=device),
122
  attention_mask = attn if past is None else None,
123
  past_key_values = past,
124
  max_new_tokens = CHUNK_TOKENS,
125
+ logits_processor= [masker],
126
  do_sample=True, temperature=0.7, top_p=0.95,
127
  return_dict_in_generate=True,
128
  use_cache=True,
129
+ return_legacy_cache=True, # <- Warnung unterdrücken
130
  )
131
 
132
+ # ----- Cache & neue Token --------------------------------------
133
  pkv = out.past_key_values
134
+ if isinstance(pkv, Cache): # HF >= 4.47
135
  pkv = pkv.to_legacy()
136
  past = pkv
137
 
138
  new = out.sequences[0, -out.num_generated_tokens :].tolist()
139
+ print("new tokens:", new[:20]) # Debug‑Print
140
 
141
+ if not new: # Safety – nichts erzeugt
142
+ raise StopIteration
143
+
144
+ # ----- Token‑Handling ------------------------------------------
145
  for t in new:
146
+ last_tok = t # speichern für nächste Runde
147
+
148
  if t == EOS_TOKEN:
149
  raise StopIteration
150
 
 
156
  if len(buf) == 7:
157
  await ws.send_bytes(decode_block(buf))
158
  buf.clear()
159
+ masker.blocks_done += 1 # nach 1. Block darf EOS
160
 
161
+ # ab nächster Runde nur 1 Token + Cache
162
  ids, attn = None, None
163
 
164
  except (StopIteration, WebSocketDisconnect):
 
172
  try:
173
  await ws.close()
174
  except RuntimeError:
175
+ pass
176
 
177
  # ── 6. Lokaler Start (uvicorn) ───────────────────────────────────────
178
  if __name__ == "__main__":