Update app.py
app.py CHANGED
@@ -3,7 +3,6 @@ import os, json, torch, asyncio
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from huggingface_hub import login
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
-from transformers.generation.utils import Cache
 from snac import SNAC
 
 # 0) Login + Device ---------------------------------------------------
@@ -108,27 +107,30 @@ async def tts(ws: WebSocket):
     buf = []
 
     while True:
-        # --- Mini-Generate -------------------------------------------
+        # --- Mini-Generate (Cache Disabled for Debugging) -------------
         gen = model.generate(
-            input_ids = …,
-            attention_mask = …,
-            past_key_values = past,
-            max_new_tokens = …,
-            logits_processor= …,
+            input_ids        = ids,     # Always use full sequence
+            attention_mask   = attn,    # Always use full attention mask
+            # past_key_values= past,    # Disabled cache
+            max_new_tokens   = CHUNK_TOKENS,
+            logits_processor = [masker],
             do_sample=True, temperature=0.7, top_p=0.95,
-            use_cache=True,
-            return_dict_in_generate=True,
-            return_legacy_cache=True
+            use_cache=False,            # Disabled cache
+            return_dict_in_generate=True,
+            return_legacy_cache=True
         )
 
         # ----- cut out the new tokens ---------------------------------
-        …
+        seq = gen.sequences[0].tolist()
+        new = seq[offset_len:]
         if not new:                 # nothing -> done
             break
         offset_len += len(new)
 
-        # ----- …
-        …
+        # ----- Update ids and attn with the full sequence (Cache Disabled) ---------
+        ids  = torch.tensor([seq], device=device)
+        attn = torch.ones_like(ids)
+        # past = gen.past_key_values  # Disabled cache access
         last_tok = new[-1]
 
         print("new tokens:", new[:25], flush=True)
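
For context, the loop this commit produces, read as straight code rather than a diff: a minimal sketch, assuming `model`, `device`, `masker`, `CHUNK_TOKENS`, and the initial `ids`/`attn` tensors are set up earlier in app.py as the surrounding lines suggest; the initialization of `offset_len` and the end-of-stream handling are assumptions, since they are not visible in this diff. With `use_cache=False` and no `past_key_values`, every iteration re-encodes the whole sequence, so per-chunk latency grows with output length (roughly quadratic overall); the commented-out `past` lines mark where KV caching would be re-enabled.

    # Sketch of the cache-free streaming loop after this commit (assumed setup).
    offset_len = ids.shape[1]      # tokens already consumed so far (assumption)
    buf = []
    while True:
        gen = model.generate(
            input_ids=ids,                    # full sequence every round
            attention_mask=attn,
            max_new_tokens=CHUNK_TOKENS,
            logits_processor=[masker],
            do_sample=True, temperature=0.7, top_p=0.95,
            use_cache=False,                  # no KV cache: simpler to debug
            return_dict_in_generate=True,
        )
        seq = gen.sequences[0].tolist()       # full sequence incl. prompt
        new = seq[offset_len:]                # only the freshly generated tokens
        if not new:                           # nothing new -> done
            break
        offset_len += len(new)
        ids  = torch.tensor([seq], device=device)  # re-feed the grown sequence
        attn = torch.ones_like(ids)
        last_tok = new[-1]
        # (termination on an end-of-speech token is presumably handled
        # elsewhere in app.py; not shown in this diff)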
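The `masker` passed via `logits_processor=[masker]` is defined elsewhere in app.py, and its exact logic is not visible here. Purely as an illustration of the `LogitsProcessor` interface imported above, a masker that confines sampling to a contiguous token-id range could look like this (the class name, bounds, and logic are hypothetical, not the app's actual masker):

    import torch
    from transformers import LogitsProcessor

    class RangeMasker(LogitsProcessor):  # hypothetical illustration only
        """Force sampling into [lo, hi) by pushing all other logits to -inf."""
        def __init__(self, lo: int, hi: int):
            self.lo, self.hi = lo, hi

        def __call__(self, input_ids, scores):
            mask = torch.full_like(scores, float("-inf"))
            mask[:, self.lo:self.hi] = 0.0   # allowed ids keep their scores
            return scores + mask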