Tomtom84 committed on
Commit
10be82b
·
verified ·
1 Parent(s): 7b0d42c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -12
app.py CHANGED
@@ -115,21 +115,64 @@ def redistribute_codes(code_list: list[int]) -> np.ndarray:
115
  async def tts_ws(ws: WebSocket):
116
  await ws.accept()
117
  try:
118
- msg = await ws.receive_text()
119
- req = json.loads(msg)
120
  text = req.get("text", "")
121
  voice = req.get("voice", "")
122
 
123
- # 1) Prompt → Codes → Audio
124
- with torch.no_grad():
125
- codes = process_single_prompt(text, voice)
126
- audio_np = redistribute_codes(codes)
127
-
128
- # 2) In PCM16 wandeln & senden
129
- pcm16 = (audio_np * 32767).astype(np.int16).tobytes()
130
- await ws.send_bytes(pcm16)
131
-
132
- # 3) sauber schließen
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  await ws.close()
134
 
135
  except WebSocketDisconnect:
 
115
  async def tts_ws(ws: WebSocket):
116
  await ws.accept()
117
  try:
118
+ msg = await ws.receive_text()
119
+ req = json.loads(msg)
120
  text = req.get("text", "")
121
  voice = req.get("voice", "")
122
 
123
+ # 1) Prompt vorbereiten
124
+ input_ids, attention_mask = prepare_inputs(text, voice)
125
+ past_kvs = None
126
+ buffer = []
127
+
128
+ # 2) Token‑für‑Token (oder in kleinen Blöcken)
129
+ while True:
130
+ # Nur max_new_tokens=50 pro Aufruf
131
+ out = model.generate(
132
+ input_ids=input_ids if past_kvs is None else None,
133
+ attention_mask=attention_mask if past_kvs is None else None,
134
+ past_key_values=past_kvs,
135
+ use_cache=True,
136
+ do_sample=True,
137
+ temperature=0.7,
138
+ top_p=0.95,
139
+ repetition_penalty=1.1,
140
+ max_new_tokens=50,
141
+ eos_token_id=128258,
142
+ return_dict_in_generate=True,
143
+ output_past_key_values=True,
144
+ return_legacy_cache=True, # falls Ihr noch das alte past_key_values-Format braucht
145
+ )
146
+
147
+ # Extrahiere neue Token (ohne die already generated ones)
148
+ new_ids = out.sequences[0, input_ids.shape[-1]:].tolist()
149
+ past_kvs = out.past_key_values
150
+
151
+ for tok in new_ids:
152
+ if tok == model.config.eos_token_id:
153
+ # Stream zu Ende
154
+ break
155
+ if tok == 128257: # Reset-Start‑Marker
156
+ buffer = []
157
+ continue
158
+ buffer.append(tok - AUDIO_OFFSET)
159
+
160
+ # Sobald wir 7 Audio‑Codes gesammelt haben → dekodieren & schicken
161
+ if len(buffer) == 7:
162
+ pcm = decode_block(buffer)
163
+ buffer = []
164
+ await ws.send_bytes(pcm)
165
+
166
+ # Wenn EOS im Chunk war, abbrechen
167
+ if model.config.eos_token_id in new_ids:
168
+ break
169
+
170
+ # Danach weiter mit nächsten 50 Tokens,
171
+ # input_ids & attention_mask nur beim ersten Aufruf nötig
172
+ input_ids = None
173
+ attention_mask = None
174
+
175
+ # 3) Am Ende WebSocket sauber schließen
176
  await ws.close()
177
 
178
  except WebSocketDisconnect: