Tomtom84 committed (verified)
Commit 1d792aa
Parent(s): cd13e5b

Update app.py

Files changed (1):
  1. app.py (+91 -79)

app.py CHANGED
@@ -18,7 +18,6 @@ model = None
 snac = None
 masker = None
 stopping_criteria = None
-# actual_eos_token_id = None # Reverted to constant below
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # 0) Login + Device ---------------------------------------------------
@@ -33,26 +32,23 @@ if HF_TOKEN:
 REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
 START_TOKEN = 128259
 NEW_BLOCK = 128257
-# --- Reverted to using the hardcoded EOS token based on user belief ---
-#EOS_TOKEN = 128258
-# --- End Reverted EOS Token ---
+EOS_TOKEN = 128258  # Ensure this is correct for the model
 AUDIO_BASE = 128266
 AUDIO_SPAN = 4096 * 7  # 28672 codes
 CODEBOOK_SIZE = 4096  # Explicitly define the codebook size
-# Create AUDIO_IDS on the correct device later in load_models
 AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
 
 # 2) Logit mask -------------------------------------------------------
-# Uses the constant EOS_TOKEN
 class AudioMask(LogitsProcessor):
     def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
         super().__init__()
-        new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
-        eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
-        self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
-        self.eos = eos_tensor
+        self.allow = torch.cat([
+            torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long),
+            audio_ids,
+        ], dim=0)
+        self.eos = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
         self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
-        self.sent_blocks = 0
+        self.sent_blocks = 0  # State: number of audio blocks sent
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
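The class above is a whitelist over the logits: every ID outside `allow` (or `allow_with_eos` once at least one audio block has been sent) is pushed to negative infinity, so sampling can only pick permitted tokens. A minimal standalone sketch of that masking step (illustrative, not the app's exact code):

```python
import torch

def apply_whitelist(scores: torch.Tensor, allow: torch.Tensor) -> torch.Tensor:
    # scores: (batch, vocab_size) logits; allow: 1-D tensor of permitted token IDs
    masked = torch.full_like(scores, float("-inf"))
    masked[:, allow] = scores[:, allow]  # keep original logits only for allowed IDs
    return masked
```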
@@ -64,43 +60,36 @@ class AudioMask(LogitsProcessor):
         self.sent_blocks = 0
 
 # 3) StoppingCriteria for EOS -----------------------------------------
-# Uses the constant EOS_TOKEN
 class EosStoppingCriteria(StoppingCriteria):
     def __init__(self, eos_token_id: int):
         self.eos_token_id = eos_token_id
-        # No warning needed here as we are intentionally using the constant
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        if self.eos_token_id is None:
-            return False
         if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
-            print(f"StoppingCriteria: EOS detected (ID: {self.eos_token_id}).") # Add log
+            # print("StoppingCriteria: EOS detected.")  # Optional: uncomment for debugging
             return True
         return False
 
 # 4) Custom AudioStreamer ----------------------------------------------
 class AudioStreamer(BaseStreamer):
-    # Pass the constant EOS_TOKEN here too
-    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
+    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
         self.ws = ws
         self.snac = snac_decoder
         self.masker = audio_mask
         self.loop = loop
         self.device = target_device
-        self.eos_token_id = eos_token_id # Store constant EOS ID
         self.buf: list[int] = []
         self.tasks = set()
 
     def _decode_block(self, block7: list[int]) -> bytes:
         """
         Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
-        NOTE: Extracts base code value (0-4095) using modulo, assuming
-        input values represent (slot_offset + code_value).
-        Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
+        Uses modulo to extract the base code value (0-4095).
+        Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
         """
         if len(block7) != 7:
-            # print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
-            return b"" # Less verbose logging
+            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
+            return b""
 
         try:
             # --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
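The extraction logic that follows (elided from this diff) strips each slot's offset so all 7 values fall in 0..4095. A hedged sketch of the block-to-levels split, assuming the common Orpheus-style SNAC layout (1 coarse, 2 medium, 4 fine codes per block; the exact slot order used by app.py is not visible here):

```python
def split_block(block7: list[int], codebook_size: int = 4096):
    vals = [v % codebook_size for v in block7]   # modulo strips each slot's offset
    l1 = [vals[0]]                               # SNAC level 1: one coarse code
    l2 = [vals[1], vals[4]]                      # SNAC level 2: two codes
    l3 = [vals[2], vals[3], vals[5], vals[6]]    # SNAC level 3: four fine codes
    return l1, l2, l3
```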
@@ -140,7 +129,10 @@ class AudioStreamer(BaseStreamer):
             audio = self.snac.decode(codes)[0]
         except Exception as e_decode:
             print(f"Streamer Error: Exception during snac.decode: {e_decode}")
-            # Add more details if needed, e.g., shapes: {[c.shape for c in codes]}
+            print(f"Input codes shapes: {[c.shape for c in codes]}")
+            print(f"Input codes dtypes: {[c.dtype for c in codes]}")
+            print(f"Input codes devices: {[c.device for c in codes]}")
+            print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
             return b""
 
         # --- Post-processing ---
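For context, the decode call above consumes the three code lists as batched LongTensors. A sketch of that shape-wrangling, assuming the `snac` package's `decode()` accepts a list with one tensor per level (the names `l1`/`l2`/`l3` follow the debug prints above):

```python
import torch

def decode_levels(snac_model, l1, l2, l3, device="cuda"):
    codes = [
        torch.tensor(l1, dtype=torch.long, device=device).unsqueeze(0),
        torch.tensor(l2, dtype=torch.long, device=device).unsqueeze(0),
        torch.tensor(l3, dtype=torch.long, device=device).unsqueeze(0),
    ]
    with torch.no_grad():
        return snac_model.decode(codes)[0]  # waveform for the first batch item
```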
@@ -159,16 +151,14 @@ class AudioStreamer(BaseStreamer):
         try:
             await self.ws.send_bytes(data)
         except WebSocketDisconnect:
-            # This is expected if client disconnects first, don't log error
-            # print("Streamer: WebSocket disconnected during send.")
-            pass
+            print("Streamer: WebSocket disconnected during send.")
         except Exception as e:
-            if "Cannot call \"send\" once a close message has been sent" in str(e) or \
-               "Connection is closed" in str(e):
-                # This is expected if client disconnects during generation, suppress repetitive logs
-                pass
-            else:
+            # Log errors other than expected disconnects more visibly maybe
+            if "Cannot call \"send\" once a close message has been sent" not in str(e):
                 print(f"Streamer: Error sending bytes: {e}")
+            # else: # Optionally print disconnect errors quietly
+            #     print("Streamer: Attempted send after close.")
+            pass  # Avoid flooding logs if client disconnects early
 
     def put(self, value: torch.LongTensor):
         """
@@ -177,40 +167,56 @@ class AudioStreamer(BaseStreamer):
         """
         if value.numel() == 0:
             return
-        new_token_ids = value.squeeze().cpu().tolist()
+        # Ensure value is on CPU and flatten to a list of ints
+        new_token_ids = value.squeeze().cpu().tolist()  # Move to CPU before list conversion
         if isinstance(new_token_ids, int):
            new_token_ids = [new_token_ids]
 
         for t in new_token_ids:
-            # No need to check for EOS here, StoppingCriteria handles it
+            # --- DEBUGGING PRINT ---
+            # Log every token ID received from the model
+            print(f"Streamer received token ID: {t}")
+            # --- END DEBUGGING ---
+
+            if t == EOS_TOKEN:
+                # print("Streamer: EOS token encountered.")  # Optional debugging
+                break  # Stop processing this batch if EOS is found
+
             if t == NEW_BLOCK:
+                # print("Streamer: NEW_BLOCK token encountered.")  # Optional debugging
                 self.buf.clear()
-                continue
+                continue  # Move to the next token
 
-            # Use the constant EOS_TOKEN for comparison if needed (e.g. for logging)
+            # Check if token is within the expected audio range
             if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                 self.buf.append(t - AUDIO_BASE)  # Store value relative to base
-            # else: # Optionally log ignored tokens
-            #     if t != self.eos_token_id: # Don't warn about the EOS token itself
-            #         print(f"Streamer Warning: Ignoring unexpected token {t}")
+            # else: # Log unexpected tokens if needed
+            #     print(f"Streamer Warning: Ignoring unexpected token {t} (outside audio range [{AUDIO_BASE}, {AUDIO_BASE + AUDIO_SPAN}))")
+            pass
 
+            # If buffer has 7 tokens, decode and send
             if len(self.buf) == 7:
                 audio_bytes = self._decode_block(self.buf)
-                self.buf.clear()
+                self.buf.clear()  # Clear buffer after processing
 
-                if audio_bytes:
+                if audio_bytes:  # Only send if decoding was successful
+                    # Schedule the async send function to run on the main event loop
                     future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                     self.tasks.add(future)
+                    # Optional: remove completed tasks to prevent a memory leak if generation is very long
                     future.add_done_callback(self.tasks.discard)
 
+                # Allow EOS only after the first full block has been processed and scheduled for sending
                 if self.masker.sent_blocks == 0:
-                    self.masker.sent_blocks = 1
+                    # print("Streamer: First audio block processed, allowing EOS.")
+                    self.masker.sent_blocks = 1  # Update state in the mask
 
     def end(self):
         """Called by generate() when generation finishes."""
         if len(self.buf) > 0:
             print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
             self.buf.clear()
+        # print("Streamer: Generation finished.")  # Optional debugging
         pass
 
 # 5) FastAPI App ------------------------------------------------------
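Note that `put()` is called from the worker thread running `generate()`, so it cannot `await`; it hands coroutines to the main event loop instead. That handoff, in isolation (a minimal sketch):

```python
import asyncio

pending: set = set()  # keep references so futures are not garbage-collected early

def schedule_on_loop(loop: asyncio.AbstractEventLoop, coro) -> None:
    future = asyncio.run_coroutine_threadsafe(coro, loop)  # thread-safe scheduling
    pending.add(future)
    future.add_done_callback(pending.discard)  # drop the reference once done
```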
@@ -218,8 +224,7 @@ app = FastAPI()
 
 @app.on_event("startup")
 async def load_models_startup():
-    # Keep global references, but EOS_TOKEN is now a constant again
-    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
+    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU, EOS_TOKEN
 
     print(f"🚀 Starting up on device: {device}")
     print("⏳ Loading models …", flush=True)
@@ -245,31 +250,41 @@ async def load_models_startup():
         torch_dtype=model_dtype,
         low_cpu_mem_usage=True,
     )
+
+    # --- Verify EOS Token ---
+    # Use the actual EOS token ID from the loaded model/tokenizer config
+    config_eos_id = model.config.eos_token_id
+    tokenizer_eos_id = tok.eos_token_id
+
+    if config_eos_id is None:
+        print("🚨 WARNING: model.config.eos_token_id is None!")
+        # Fallback or default? Let's use the constant for now, but this needs checking.
+        final_eos_token_id = EOS_TOKEN
+    elif tokenizer_eos_id is not None and config_eos_id != tokenizer_eos_id:
+        print(f"⚠️ WARNING: Mismatch! model.config.eos_token_id ({config_eos_id}) != tok.eos_token_id ({tokenizer_eos_id}). Using model config ID.")
+        final_eos_token_id = config_eos_id
+    else:
+        final_eos_token_id = config_eos_id
+
+    # Update the global constant if it differs or wasn't set properly by config
+    if final_eos_token_id != EOS_TOKEN:
+        print(f"🔄 Updating EOS_TOKEN constant from {EOS_TOKEN} to {final_eos_token_id}")
+        EOS_TOKEN = final_eos_token_id  # Update the global constant
+
+    # Set pad_token_id to the determined EOS token ID
+    model.config.pad_token_id = EOS_TOKEN
+    print(f"Using EOS Token ID: {EOS_TOKEN}")
+    # --- End Verify EOS Token ---
+
     print(f"Model loaded to {model.device} with dtype {model.dtype}.")
     model.eval()
 
-    # --- Print comparison for EOS token IDs but use the constant ---
-    conf_eos = model.config.eos_token_id
-    tok_eos = tok.eos_token_id
-    print(f"Model Config EOS ID: {conf_eos}")
-    print(f"Tokenizer EOS ID: {tok_eos}")
-    print(f"Using Constant EOS_TOKEN: {EOS_TOKEN}") # State the used constant
-    if conf_eos != EOS_TOKEN or tok_eos != EOS_TOKEN:
-        print(f"⚠️ WARNING: Constant EOS_TOKEN {EOS_TOKEN} differs from model/tokenizer IDs ({conf_eos}/{tok_eos}).")
-    # --- End EOS comparison ---
-
-    # Set pad_token_id if None (use the constant EOS)
-    if model.config.pad_token_id is None:
-        print(f"Setting model.config.pad_token_id to Constant EOS token ID ({EOS_TOKEN})")
-        model.config.pad_token_id = EOS_TOKEN
-
     audio_ids_device = AUDIO_IDS_CPU.to(device)
-    # Pass the constant EOS_TOKEN to the mask
-    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
+    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)  # Use updated EOS_TOKEN
     print("AudioMask initialized.")
 
-    # Pass the constant EOS_TOKEN to the stopping criteria
-    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
+    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])  # Use updated EOS_TOKEN
     print("StoppingCriteria initialized.")
 
     print("✅ Models loaded and ready!", flush=True)
@@ -296,7 +311,6 @@ def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
 # 7) WebSocket endpoint (simplified with streamer) ---------------------
 @app.websocket("/ws/tts")
 async def tts(ws: WebSocket):
-    # No need for global actual_eos_token_id
     await ws.accept()
     print("🔌 Client connected")
     streamer = None
@@ -317,28 +331,25 @@ async def tts(ws: WebSocket):
             print(f"Generating audio for: '{text}' with voice '{voice}'")
             ids, attn = build_prompt(text, voice)
             masker.reset()
-            # Pass the constant EOS_TOKEN to streamer
-            streamer = AudioStreamer(ws, snac, masker, main_loop, device, EOS_TOKEN)
+            streamer = AudioStreamer(ws, snac, masker, main_loop, device)
 
             print("Starting generation in background thread...")
-            # Use sampling parameters with anti-repetition measures
+            # --- DEBUGGING: Adjusted Generation Parameters ---
             await asyncio.to_thread(
                 model.generate,
                 input_ids=ids,
                 attention_mask=attn,
-                max_new_tokens=2500, # Or adjust as needed
+                max_new_tokens=1500,  # Keep lower for faster debugging cycles initially
                 logits_processor=[masker],
                 stopping_criteria=stopping_criteria,
-                # --- Sampling Parameters with Anti-Repetition ---
                 do_sample=True,
-                temperature=0.6, # Adjust if needed
-                top_p=0.9, # Adjust if needed
-                repetition_penalty=1.2, # Increased (experiment!)
-                no_repeat_ngram_size=4, # Added (experiment!)
-                # --- End Sampling Parameters ---
+                # --- Adjusted Parameters for Debugging Repetition ---
+                temperature=0.7,  # Slightly higher temperature
+                # top_p=0.9,  # Commented out top_p for simpler testing
+                repetition_penalty=1.2,  # Slightly stronger penalty
+                # --- End Adjusted Parameters ---
                 use_cache=True,
-                streamer=streamer,
-                eos_token_id=EOS_TOKEN # Explicitly pass constant EOS ID
+                streamer=streamer,
             )
             print("Generation thread finished.")
 
@@ -371,8 +382,7 @@ async def tts(ws: WebSocket):
         try:
             await ws.close(code=1000)
         except RuntimeError as e_close:
-            if "Cannot call \"send\"" not in str(e_close) and "Connection is closed" not in str(e_close):
-                print(f"Runtime error closing websocket: {e_close}")
+            print(f"Runtime error closing websocket: {e_close}")
         except Exception as e_close_final:
             print(f"Error closing websocket: {e_close_final}")
     elif ws.client_state.name != "DISCONNECTED":
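Note that the generation hunk above also dropped the explicit `eos_token_id=EOS_TOKEN` argument, so `generate()` now falls back to the model's generation config plus the custom stopping criteria. Passing the ID explicitly remains a supported transformers kwarg (sketch; the value is illustrative):

```python
output_ids = model.generate(
    input_ids=ids,
    attention_mask=attn,
    max_new_tokens=1500,
    do_sample=True,
    eos_token_id=128258,  # assumes this is the model's true EOS ID
)
```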
@@ -383,4 +393,6 @@ async def tts(ws: WebSocket):
 if __name__ == "__main__":
     import uvicorn
     print("Starting Uvicorn server...")
+    # Note: Consider running with --workers 1 if you face issues with globals/GPU memory
+    # uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info", workers=1)
     uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
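For reference, a hypothetical client for the /ws/tts endpoint. The server reads `text` and `voice` from the request, but the exact JSON schema is not visible in this diff, so the field values below are assumptions:

```python
import asyncio
import json

import websockets  # pip install websockets

async def main():
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        await ws.send(json.dumps({"text": "Hallo, Welt!", "voice": "default"}))
        with open("out.raw", "wb") as f:
            async for msg in ws:
                if isinstance(msg, bytes):
                    f.write(msg)  # raw audio bytes, streamed block by block

asyncio.run(main())
```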
 