Update app.py
app.py CHANGED
@@ -18,7 +18,6 @@ model = None
 snac = None
 masker = None
 stopping_criteria = None
-# actual_eos_token_id = None # Reverted to constant below
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # 0) Login + Device ---------------------------------------------------
@@ -33,26 +32,23 @@ if HF_TOKEN:
 REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
 START_TOKEN = 128259
 NEW_BLOCK = 128257
-
-#EOS_TOKEN = 128258
-# --- End Reverted EOS Token ---
+EOS_TOKEN = 128258 # Ensure this is correct for the model
 AUDIO_BASE = 128266
 AUDIO_SPAN = 4096 * 7 # 28672 Codes
 CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
-# Create AUDIO_IDS on the correct device later in load_models
 AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
 
 # 2) Logit‑Mask -------------------------------------------------------
-# Uses the constant EOS_TOKEN
 class AudioMask(LogitsProcessor):
     def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
         super().__init__()
-
-
-
-
+        self.allow = torch.cat([
+            torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long),
+            audio_ids
+        ], dim=0)
+        self.eos = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
         self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
-        self.sent_blocks = 0
+        self.sent_blocks = 0 # State: Number of audio blocks sent
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
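Note: the rest of `AudioMask.__call__` is outside this hunk. For reference, a minimal sketch of the standard `LogitsProcessor` masking pattern it presumably follows, where everything outside `current_allow` is pushed to `-inf` so sampling can only pick NEW_BLOCK, audio codes, and (after the first block) EOS. The function body here is an illustration, not the file's actual code:

```python
import torch

def apply_audio_mask(scores: torch.FloatTensor, current_allow: torch.Tensor) -> torch.FloatTensor:
    # scores: (batch, vocab_size) logits for the next token
    mask = torch.full_like(scores, float("-inf"))  # forbid every token...
    mask[:, current_allow] = 0.0                   # ...except the allowed ids
    return scores + mask
```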
@@ -64,43 +60,36 @@ class AudioMask(LogitsProcessor):
         self.sent_blocks = 0
 
 # 3) StoppingCriteria für EOS ---------------------------------------
-# Uses the constant EOS_TOKEN
 class EosStoppingCriteria(StoppingCriteria):
     def __init__(self, eos_token_id: int):
         self.eos_token_id = eos_token_id
-        # No warning needed here as we are intentionally using the constant
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        if self.eos_token_id is None:
-            return False
         if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
-            print("StoppingCriteria: EOS detected.")
+            # print("StoppingCriteria: EOS detected.") # Optional: Uncomment for debugging
             return True
         return False
 
 # 4) Benutzerdefinierter AudioStreamer -------------------------------
 class AudioStreamer(BaseStreamer):
-
-    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
+    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
         self.ws = ws
         self.snac = snac_decoder
         self.masker = audio_mask
         self.loop = loop
         self.device = target_device
-        self.eos_token_id = eos_token_id # Store constant EOS ID
         self.buf: list[int] = []
         self.tasks = set()
 
     def _decode_block(self, block7: list[int]) -> bytes:
         """
         Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
-
-
-        Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
+        Uses modulo to extract base code value (0-4095).
+        Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
         """
         if len(block7) != 7:
-
-            return b""
+            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
+            return b""
 
         try:
             # --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
@@ -140,7 +129,10 @@ class AudioStreamer(BaseStreamer):
             audio = self.snac.decode(codes)[0]
         except Exception as e_decode:
             print(f"Streamer Error: Exception during snac.decode: {e_decode}")
-
+            print(f"Input codes shapes: {[c.shape for c in codes]}")
+            print(f"Input codes dtypes: {[c.dtype for c in codes]}")
+            print(f"Input codes devices: {[c.device for c in codes]}")
+            print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
             return b""
 
         # --- Post-processing ---
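Note: the `l1`/`l2`/`l3` lists in the new debug prints are the three SNAC codebook layers. The actual slot-to-layer mapping lives in the elided lines of `_decode_block`; as a hypothetical sketch, the common Orpheus convention assigns slot 0 to layer 1, slots 1 and 4 to layer 2, and slots 2, 3, 5, 6 to layer 3 (verify against the full file before relying on this):

```python
import torch

def block_to_snac_codes(block7: list[int], device: str) -> list[torch.Tensor]:
    # block7 holds 7 values with AUDIO_BASE already subtracted; each slot is
    # offset by slot_index * CODEBOOK_SIZE, so modulo recovers the raw code.
    vals = [v % 4096 for v in block7]          # CODEBOOK_SIZE = 4096
    l1 = [vals[0]]                             # coarse layer: 1 code per frame
    l2 = [vals[1], vals[4]]                    # middle layer: 2 codes per frame
    l3 = [vals[2], vals[3], vals[5], vals[6]]  # fine layer: 4 codes per frame
    return [torch.tensor(layer, dtype=torch.long, device=device).unsqueeze(0)
            for layer in (l1, l2, l3)]         # shapes (1,1), (1,2), (1,4)
```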
@@ -159,16 +151,14 @@ class AudioStreamer(BaseStreamer):
         try:
             await self.ws.send_bytes(data)
         except WebSocketDisconnect:
-
-            # print("Streamer: WebSocket disconnected during send.")
-            pass
+            print("Streamer: WebSocket disconnected during send.")
         except Exception as e:
-
-
-            # This is expected if client disconnects during generation, suppress repetitive logs
-            pass
-            else:
+            # Log errors other than expected disconnects more visibly maybe
+            if "Cannot call \"send\" once a close message has been sent" not in str(e):
                 print(f"Streamer: Error sending bytes: {e}")
+            # else: # Optionally print disconnect errors quietly
+            #     print("Streamer: Attempted send after close.")
+            pass # Avoid flooding logs if client disconnects early
 
     def put(self, value: torch.LongTensor):
         """
@@ -177,40 +167,56 @@ class AudioStreamer(BaseStreamer):
         """
         if value.numel() == 0:
             return
-
+        # Ensure value is on CPU and flatten to a list of ints
+        new_token_ids = value.squeeze().cpu().tolist() # Move to CPU before list conversion
         if isinstance(new_token_ids, int):
             new_token_ids = [new_token_ids]
 
         for t in new_token_ids:
-            #
+            # --- DEBUGGING PRINT ---
+            # Log every token ID received from the model
+            print(f"Streamer received token ID: {t}")
+            # --- END DEBUGGING ---
+
+            if t == EOS_TOKEN:
+                # print("Streamer: EOS token encountered.") # Optional debugging
+                break # Stop processing this batch if EOS is found
+
             if t == NEW_BLOCK:
+                # print("Streamer: NEW_BLOCK token encountered.") # Optional debugging
                 self.buf.clear()
-                continue
+                continue # Move to the next token
 
-            #
+            # Check if token is within the expected audio range
             if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                 self.buf.append(t - AUDIO_BASE) # Store value relative to base
-            # else: #
-
-
+            # else: # Log unexpected tokens if needed
+            #     print(f"Streamer Warning: Ignoring unexpected token {t} (outside audio range [{AUDIO_BASE}, {AUDIO_BASE + AUDIO_SPAN}))")
+                pass
 
+            # If buffer has 7 tokens, decode and send
             if len(self.buf) == 7:
                 audio_bytes = self._decode_block(self.buf)
-                self.buf.clear()
+                self.buf.clear() # Clear buffer after processing
 
-                if audio_bytes:
+                if audio_bytes: # Only send if decoding was successful
+                    # Schedule the async send function to run on the main event loop
                     future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                     self.tasks.add(future)
+                    # Optional: Remove completed tasks to prevent memory leak if generation is very long
                     future.add_done_callback(self.tasks.discard)
 
+            # Allow EOS only after the first full block has been processed and scheduled for sending
             if self.masker.sent_blocks == 0:
-
+                # print("Streamer: First audio block processed, allowing EOS.")
+                self.masker.sent_blocks = 1 # Update state in the mask
 
     def end(self):
         """Called by generate() when generation finishes."""
         if len(self.buf) > 0:
             print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
             self.buf.clear()
+        # print(f"Streamer: Generation finished.") # Optional debugging
         pass
 
 # 5) FastAPI App ------------------------------------------------------
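Note: `put()` runs on the worker thread that drives `model.generate` (via `asyncio.to_thread` below), while `ws.send_bytes` must execute on the main event loop; `asyncio.run_coroutine_threadsafe` is the bridge. A self-contained sketch of that pattern with toy payloads, not the app's code:

```python
import asyncio

async def send(data: bytes):
    # Stands in for ws.send_bytes(data), which must run on the event loop.
    print(f"sending {len(data)} bytes")

def generate_worker(loop: asyncio.AbstractEventLoop):
    # Runs in a plain thread, like model.generate calling AudioStreamer.put().
    for block in (b"\x00" * 14, b"\x01" * 14):
        fut = asyncio.run_coroutine_threadsafe(send(block), loop)
        fut.result()  # wait here; the app instead stores the future and discards it on completion

async def main():
    loop = asyncio.get_running_loop()
    await asyncio.to_thread(generate_worker, loop)

asyncio.run(main())
```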
@@ -218,8 +224,7 @@ app = FastAPI()
 
 @app.on_event("startup")
 async def load_models_startup():
-
-    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
+    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU, EOS_TOKEN
 
     print(f"🚀 Starting up on device: {device}")
    print("⏳ Lade Modelle …", flush=True)
@@ -245,31 +250,41 @@ async def load_models_startup():
         torch_dtype=model_dtype,
         low_cpu_mem_usage=True,
     )
+
+    # --- Verify EOS Token ---
+    # Use the actual EOS token ID from the loaded model/tokenizer config
+    config_eos_id = model.config.eos_token_id
+    tokenizer_eos_id = tok.eos_token_id
+
+    if config_eos_id is None:
+        print("🚨 WARNING: model.config.eos_token_id is None!")
+        # Fallback or default? Let's use the constant for now, but this needs checking.
+        final_eos_token_id = EOS_TOKEN
+    elif tokenizer_eos_id is not None and config_eos_id != tokenizer_eos_id:
+        print(f"⚠️ WARNING: Mismatch! model.config.eos_token_id ({config_eos_id}) != tok.eos_token_id ({tokenizer_eos_id}). Using model config ID.")
+        final_eos_token_id = config_eos_id
+    else:
+        final_eos_token_id = config_eos_id
+
+    # Update the global constant if it differs or wasn't set properly by config
+    if final_eos_token_id != EOS_TOKEN:
+        print(f"🔄 Updating EOS_TOKEN constant from {EOS_TOKEN} to {final_eos_token_id}")
+        EOS_TOKEN = final_eos_token_id # Update the global constant
+
+    # Set pad_token_id to the determined EOS token ID
+    model.config.pad_token_id = EOS_TOKEN
+    print(f"Using EOS Token ID: {EOS_TOKEN}")
+    # --- End Verify EOS Token ---
+
+
     print(f"Model loaded to {model.device} with dtype {model.dtype}.")
     model.eval()
 
-    # --- Print comparison for EOS token IDs but use the constant ---
-    conf_eos = model.config.eos_token_id
-    tok_eos = tok.eos_token_id
-    print(f"Model Config EOS ID: {conf_eos}")
-    print(f"Tokenizer EOS ID: {tok_eos}")
-    print(f"Using Constant EOS_TOKEN: {EOS_TOKEN}") # State the used constant
-    if conf_eos != EOS_TOKEN or tok_eos != EOS_TOKEN:
-        print(f"⚠️ WARNING: Constant EOS_TOKEN {EOS_TOKEN} differs from model/tokenizer IDs ({conf_eos}/{tok_eos}).")
-    # --- End EOS comparison ---
-
-    # Set pad_token_id if None (use the constant EOS)
-    if model.config.pad_token_id is None:
-        print(f"Setting model.config.pad_token_id to Constant EOS token ID ({EOS_TOKEN})")
-        model.config.pad_token_id = EOS_TOKEN
-
     audio_ids_device = AUDIO_IDS_CPU.to(device)
-
-    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
+    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN) # Use updated EOS_TOKEN
     print("AudioMask initialized.")
 
-
-    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
+    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)]) # Use updated EOS_TOKEN
     print("StoppingCriteria initialized.")
 
     print("✅ Modelle geladen und bereit!", flush=True)
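Note: the startup hook now resolves the EOS id from the loaded config instead of trusting the hardcoded constant. A quick standalone sanity check (assumes network access; same repo id as above) to confirm what 128258 actually decodes to:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1")
print("tokenizer eos_token_id:", tok.eos_token_id)
print(tok.convert_ids_to_tokens([128257, 128258, 128259]))  # NEW_BLOCK, EOS?, START_TOKEN
```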
@@ -296,7 +311,6 @@ def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
 # 7) WebSocket‑Endpoint (vereinfacht mit Streamer) ---------------------
 @app.websocket("/ws/tts")
 async def tts(ws: WebSocket):
-    # No need for global actual_eos_token_id
     await ws.accept()
     print("🔌 Client connected")
     streamer = None
@@ -317,28 +331,25 @@ async def tts(ws: WebSocket):
             print(f"Generating audio for: '{text}' with voice '{voice}'")
             ids, attn = build_prompt(text, voice)
             masker.reset()
-
-            streamer = AudioStreamer(ws, snac, masker, main_loop, device, EOS_TOKEN)
+            streamer = AudioStreamer(ws, snac, masker, main_loop, device)
 
             print("Starting generation in background thread...")
-            #
+            # --- DEBUGGING: Adjusted Generation Parameters ---
             await asyncio.to_thread(
                 model.generate,
                 input_ids=ids,
                 attention_mask=attn,
-                max_new_tokens=
+                max_new_tokens=1500, # Keep lower for faster debugging cycles initially
                 logits_processor=[masker],
                 stopping_criteria=stopping_criteria,
-                # ---
+                # --- Adjusted Parameters for Debugging Repetition ---
                 do_sample=True,
-                temperature=0.
-                top_p=0.9,
-                repetition_penalty=1.2,
-
-                # --- End Sampling Parameters ---
+                temperature=0.7, # Slightly higher temperature
+                # top_p=0.9, # Commented out top_p for simpler testing
+                repetition_penalty=1.2, # Slightly stronger penalty
+                # --- End Adjusted Parameters ---
                 use_cache=True,
-                streamer=streamer
-                eos_token_id=EOS_TOKEN # Explicitly pass constant EOS ID
+                streamer=streamer
             )
             print("Generation thread finished.")
 
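Note: for manual testing of this endpoint, a hypothetical client sketch. It assumes the `websockets` package and that the handler (elided above) reads a JSON message with `text` and `voice` keys, matching the variables used in this hunk; the voice name is a placeholder, not a value confirmed by the diff:

```python
import asyncio
import json
import websockets  # pip install websockets

async def main():
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        await ws.send(json.dumps({"text": "Hallo, Welt!", "voice": "placeholder_voice"}))
        with open("out.raw", "wb") as f:
            async for msg in ws:          # server streams raw audio as binary frames
                if isinstance(msg, bytes):
                    f.write(msg)          # loop ends when the server closes (code 1000)

asyncio.run(main())
```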
@@ -371,8 +382,7 @@ async def tts(ws: WebSocket):
             try:
                 await ws.close(code=1000)
             except RuntimeError as e_close:
-
-                print(f"Runtime error closing websocket: {e_close}")
+                print(f"Runtime error closing websocket: {e_close}")
             except Exception as e_close_final:
                 print(f"Error closing websocket: {e_close_final}")
         elif ws.client_state.name != "DISCONNECTED":
@@ -383,4 +393,6 @@ async def tts(ws: WebSocket):
 if __name__ == "__main__":
     import uvicorn
     print("Starting Uvicorn server...")
+    # Note: Consider running with --workers 1 if you face issues with globals/GPU memory
+    # uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info", workers=1)
     uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")