Tomtom84 committed on
Commit
9ef5e61
·
verified ·
1 Parent(s): 10540d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -97,41 +97,42 @@ async def tts(ws: WebSocket):
97
  await ws.accept()
98
  try:
99
  req = json.loads(await ws.receive_text())
100
- text = req.get("text","")
101
- voice = req.get("voice","Jakob")
102
 
103
- ids, attn = build_prompt(text, voice)
104
- past = None
105
- buf = []
106
 
107
  while True:
108
- out = model.generate(
109
  input_ids=ids if past is None else None,
110
  attention_mask=attn if past is None else None,
111
  past_key_values=past,
112
  max_new_tokens=CHUNK_TOKENS,
113
  logits_processor=[MASKER],
114
- do_sample=True, temperature=0.7, top_p=0.95,
115
- use_cache=True,
116
  return_dict_in_generate=True,
 
 
117
  )
118
- pkv = out.past_key_values
119
- if isinstance(pkv, Cache):
120
- pkv = pkv.to_legacy()
121
- past_kvs = pkv
122
- newtok = out.sequences[0,-out.num_generated_tokens:].tolist()
123
-
124
- for t in newtok:
125
- if t==EOS_TOKEN:
126
- raise StopIteration
127
- if t==NEW_BLOCK_TOKEN:
128
- buf.clear(); continue
129
- buf.append(t-AUDIO_BASE)
130
- if len(buf)==7:
131
- await ws.send_bytes(decode_snac(buf))
 
 
132
  buf.clear()
133
 
134
- # ab jetzt nur noch mit Cache weiter‑generieren
135
  ids, attn = None, None
136
 
137
  except (StopIteration, WebSocketDisconnect):
 
97
  await ws.accept()
98
  try:
99
  req = json.loads(await ws.receive_text())
100
+ ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
 
101
 
102
+ past = None # Cache
103
+ buf = []
 
104
 
105
  while True:
106
+ gen = model.generate(
107
  input_ids=ids if past is None else None,
108
  attention_mask=attn if past is None else None,
109
  past_key_values=past,
110
  max_new_tokens=CHUNK_TOKENS,
111
  logits_processor=[MASKER],
112
+ do_sample=True, top_p=0.95, temperature=0.7,
 
113
  return_dict_in_generate=True,
114
+ use_cache=True,
115
+ return_legacy_cache=True, # ⚡ wichtig
116
  )
117
+
118
+ # Legacy‑Cache weitergeben
119
+ past = gen.past_key_values
120
+
121
+ # die tatsächlich erzeugten neuen Tokens
122
+ new_tok = gen.sequences[0, -gen.num_generated_tokens :].tolist()
123
+
124
+ for t in new_tok:
125
+ if t == EOS_TOKEN:
126
+ raise StopAsyncIteration
127
+ if t == NEW_BLOCK_TOKEN:
128
+ buf.clear()
129
+ continue
130
+ buf.append(t - AUDIO_BASE)
131
+ if len(buf) == 7:
132
+ await ws.send_bytes(decode_block(buf))
133
  buf.clear()
134
 
135
+ # ab jetzt nur Cache, keine neuen IDs mehr nötig
136
  ids, attn = None, None
137
 
138
  except (StopIteration, WebSocketDisconnect):