Tomtom84 committed on
Commit
d11cc63
·
verified ·
1 Parent(s): 96dc59a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -70
app.py CHANGED
@@ -18,7 +18,7 @@ model = None
18
  snac = None
19
  masker = None
20
  stopping_criteria = None
21
- actual_eos_token_id = None # Will be determined during startup
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
  # 0) Login + Device ---------------------------------------------------
@@ -31,10 +31,11 @@ if HF_TOKEN:
31
 
32
  # 1) Konstanten -------------------------------------------------------
33
  REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
34
- # CHUNK_TOKENS = 50 # Not directly used by us with the streamer approach
35
  START_TOKEN = 128259
36
  NEW_BLOCK = 128257
37
- # EOS_TOKEN = 128258 # REMOVED - Will be determined from model/tokenizer config
 
 
38
  AUDIO_BASE = 128266
39
  AUDIO_SPAN = 4096 * 7 # 28672 Codes
40
  CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
@@ -42,61 +43,51 @@ CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
42
  AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
43
 
44
  # 2) Logit‑Mask -------------------------------------------------------
45
- # Uses the dynamically determined EOS token ID
46
  class AudioMask(LogitsProcessor):
47
  def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
48
  super().__init__()
49
- # Ensure input tensors are Long type for concatenation if needed, although indices are usually int
50
  new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
51
  eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
52
-
53
- # Allow NEW_BLOCK and all valid audio tokens initially
54
  self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
55
- self.eos = eos_tensor # Store EOS token ID as tensor
56
- self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0) # Precompute combined tensor
57
- self.sent_blocks = 0 # State: Number of audio blocks sent
58
 
59
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
60
- # Determine which tokens are allowed based on whether blocks have been sent
61
  current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
62
-
63
- # Create a mask initialized to negative infinity
64
  mask = torch.full_like(scores, float("-inf"))
65
- # Set allowed token scores to 0 (effectively allowing them)
66
  mask[:, current_allow] = 0
67
- # Apply the mask to the scores
68
  return scores + mask
69
 
70
  def reset(self):
71
- """Resets the state for a new generation request."""
72
  self.sent_blocks = 0
73
 
74
  # 3) StoppingCriteria für EOS ---------------------------------------
75
- # Uses the dynamically determined EOS token ID
76
  class EosStoppingCriteria(StoppingCriteria):
77
  def __init__(self, eos_token_id: int):
78
  self.eos_token_id = eos_token_id
79
- if self.eos_token_id is None:
80
- print("⚠️ EosStoppingCriteria initialized with eos_token_id=None!")
81
 
82
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
83
  if self.eos_token_id is None:
84
- return False # Cannot stop if EOS ID is unknown
85
- # Check if the *last* generated token is the EOS token
86
  if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
87
- # print("StoppingCriteria: EOS detected.")
88
  return True
89
  return False
90
 
91
  # 4) Benutzerdefinierter AudioStreamer -------------------------------
92
  class AudioStreamer(BaseStreamer):
 
93
  def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
94
  self.ws = ws
95
  self.snac = snac_decoder
96
  self.masker = audio_mask
97
  self.loop = loop
98
  self.device = target_device
99
- self.eos_token_id = eos_token_id # Store EOS ID for potential use in put (optional)
100
  self.buf: list[int] = []
101
  self.tasks = set()
102
 
@@ -108,8 +99,8 @@ class AudioStreamer(BaseStreamer):
108
  Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
109
  """
110
  if len(block7) != 7:
111
- print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
112
- return b""
113
 
114
  try:
115
  # --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
@@ -129,7 +120,7 @@ class AudioStreamer(BaseStreamer):
129
  except IndexError:
130
  print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
131
  return b""
132
- except Exception as e_map: # Catch potential issues with modulo/mapping
133
  print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
134
  return b""
135
 
@@ -149,10 +140,7 @@ class AudioStreamer(BaseStreamer):
149
  audio = self.snac.decode(codes)[0]
150
  except Exception as e_decode:
151
  print(f"Streamer Error: Exception during snac.decode: {e_decode}")
152
- print(f"Input codes shapes: {[c.shape for c in codes]}")
153
- print(f"Input codes dtypes: {[c.dtype for c in codes]}")
154
- print(f"Input codes devices: {[c.device for c in codes]}")
155
- print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
156
  return b""
157
 
158
  # --- Post-processing ---
@@ -171,10 +159,12 @@ class AudioStreamer(BaseStreamer):
171
  try:
172
  await self.ws.send_bytes(data)
173
  except WebSocketDisconnect:
174
- print("Streamer: WebSocket disconnected during send.")
 
 
175
  except Exception as e:
176
- # Handle cases where sending fails after connection closed
177
- if "Cannot call \"send\" once a close message has been sent" in str(e):
178
  # This is expected if client disconnects during generation, suppress repetitive logs
179
  pass
180
  else:
@@ -187,7 +177,6 @@ class AudioStreamer(BaseStreamer):
187
  """
188
  if value.numel() == 0:
189
  return
190
- # Ensure value is on CPU and flatten to a list of ints
191
  new_token_ids = value.squeeze().cpu().tolist()
192
  if isinstance(new_token_ids, int):
193
  new_token_ids = [new_token_ids]
@@ -198,23 +187,22 @@ class AudioStreamer(BaseStreamer):
198
  self.buf.clear()
199
  continue
200
 
 
201
  if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
202
  self.buf.append(t - AUDIO_BASE) # Store value relative to base
203
- # else: # Optionally log ignored tokens outside audio range
204
- # if t != self.eos_token_id: # Don't warn about the EOS token itself
205
- # print(f"Streamer Warning: Ignoring unexpected token {t}")
206
 
207
  if len(self.buf) == 7:
208
  audio_bytes = self._decode_block(self.buf)
209
  self.buf.clear()
210
 
211
  if audio_bytes:
212
- # Schedule the async send function to run on the main event loop
213
  future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
214
  self.tasks.add(future)
215
  future.add_done_callback(self.tasks.discard)
216
 
217
- # Allow EOS only after the first full block has been processed
218
  if self.masker.sent_blocks == 0:
219
  self.masker.sent_blocks = 1
220
 
@@ -230,7 +218,8 @@ app = FastAPI()
230
 
231
  @app.on_event("startup")
232
  async def load_models_startup():
233
- global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU, actual_eos_token_id
 
234
 
235
  print(f"🚀 Starting up on device: {device}")
236
  print("⏳ Lade Modelle …", flush=True)
@@ -259,34 +248,28 @@ async def load_models_startup():
259
  print(f"Model loaded to {model.device} with dtype {model.dtype}.")
260
  model.eval()
261
 
262
- # --- Determine and set the correct EOS token ID ---
263
  conf_eos = model.config.eos_token_id
264
  tok_eos = tok.eos_token_id
265
  print(f"Model Config EOS ID: {conf_eos}")
266
  print(f"Tokenizer EOS ID: {tok_eos}")
 
 
 
 
267
 
268
- if conf_eos is not None:
269
- actual_eos_token_id = conf_eos
270
- elif tok_eos is not None:
271
- actual_eos_token_id = tok_eos
272
- print(f"⚠️ Model config EOS ID is None, using Tokenizer EOS ID: {actual_eos_token_id}")
273
- else:
274
- raise ValueError("Could not determine EOS token ID from model config or tokenizer.")
275
-
276
- print(f"Using EOS Token ID: {actual_eos_token_id}")
277
- # Set pad_token_id to eos_token_id if not already set (common practice for generation)
278
  if model.config.pad_token_id is None:
279
- print(f"Setting model.config.pad_token_id to EOS token ID ({actual_eos_token_id})")
280
- model.config.pad_token_id = actual_eos_token_id
281
- # --- End EOS Token ID determination ---
282
 
283
  audio_ids_device = AUDIO_IDS_CPU.to(device)
284
- # Pass the determined EOS ID to the mask
285
- masker = AudioMask(audio_ids_device, NEW_BLOCK, actual_eos_token_id)
286
  print("AudioMask initialized.")
287
 
288
- # Pass the determined EOS ID to the stopping criteria
289
- stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(actual_eos_token_id)])
290
  print("StoppingCriteria initialized.")
291
 
292
  print("✅ Modelle geladen und bereit!", flush=True)
@@ -313,7 +296,7 @@ def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
313
  # 7) WebSocket‑Endpoint (vereinfacht mit Streamer) ---------------------
314
  @app.websocket("/ws/tts")
315
  async def tts(ws: WebSocket):
316
- global actual_eos_token_id # Ensure we can access the determined EOS ID
317
  await ws.accept()
318
  print("🔌 Client connected")
319
  streamer = None
@@ -334,27 +317,28 @@ async def tts(ws: WebSocket):
334
  print(f"Generating audio for: '{text}' with voice '{voice}'")
335
  ids, attn = build_prompt(text, voice)
336
  masker.reset()
337
- # Pass the determined EOS ID to the streamer as well (optional, for logging/checks)
338
- streamer = AudioStreamer(ws, snac, masker, main_loop, device, actual_eos_token_id)
339
 
340
  print("Starting generation in background thread...")
341
- # Use sampling parameters to avoid repetition
342
  await asyncio.to_thread(
343
  model.generate,
344
  input_ids=ids,
345
  attention_mask=attn,
346
- max_new_tokens=2500, # Increased slightly, adjust as needed
347
  logits_processor=[masker],
348
  stopping_criteria=stopping_criteria,
349
- # --- Sampling Parameters ---
350
  do_sample=True,
351
- temperature=0.6,
352
- top_p=0.9,
353
- repetition_penalty=1.15,
 
354
  # --- End Sampling Parameters ---
355
  use_cache=True,
356
  streamer=streamer,
357
- eos_token_id=actual_eos_token_id # Explicitly pass correct EOS ID here too
358
  )
359
  print("Generation thread finished.")
360
 
@@ -387,8 +371,7 @@ async def tts(ws: WebSocket):
387
  try:
388
  await ws.close(code=1000)
389
  except RuntimeError as e_close:
390
- # Suppress "Cannot call 'send'..." error during final close if already disconnected
391
- if "Cannot call \"send\"" not in str(e_close):
392
  print(f"Runtime error closing websocket: {e_close}")
393
  except Exception as e_close_final:
394
  print(f"Error closing websocket: {e_close_final}")
 
18
  snac = None
19
  masker = None
20
  stopping_criteria = None
21
+ # actual_eos_token_id = None # Reverted to constant below
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
  # 0) Login + Device ---------------------------------------------------
 
31
 
32
  # 1) Konstanten -------------------------------------------------------
33
  REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
 
34
  START_TOKEN = 128259
35
  NEW_BLOCK = 128257
36
+ # --- Reverted to the hard-coded EOS token ID; the model/tokenizer EOS IDs are only logged for comparison at startup ---
37
+ EOS_TOKEN = 128258
38
+ # --- End Reverted EOS Token ---
39
  AUDIO_BASE = 128266
40
  AUDIO_SPAN = 4096 * 7 # 28672 Codes
41
  CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
 
43
  AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
44
 
45
  # 2) Logit‑Mask -------------------------------------------------------
46
+ # Uses the constant EOS_TOKEN
47
  class AudioMask(LogitsProcessor):
48
  def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
49
  super().__init__()
 
50
  new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
51
  eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
 
 
52
  self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
53
+ self.eos = eos_tensor
54
+ self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
55
+ self.sent_blocks = 0
56
 
57
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 
58
  current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
 
 
59
  mask = torch.full_like(scores, float("-inf"))
 
60
  mask[:, current_allow] = 0
 
61
  return scores + mask
62
 
63
  def reset(self):
 
64
  self.sent_blocks = 0
65
 
66
  # 3) StoppingCriteria für EOS ---------------------------------------
67
+ # Uses the constant EOS_TOKEN
68
  class EosStoppingCriteria(StoppingCriteria):
69
  def __init__(self, eos_token_id: int):
70
  self.eos_token_id = eos_token_id
71
+ # No warning needed here as we are intentionally using the constant
 
72
 
73
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
74
  if self.eos_token_id is None:
75
+ return False
 
76
  if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
77
+ print(f"StoppingCriteria: EOS detected (ID: {self.eos_token_id}).") # Add log
78
  return True
79
  return False
80
 
81
  # 4) Benutzerdefinierter AudioStreamer -------------------------------
82
  class AudioStreamer(BaseStreamer):
83
+ # Pass the constant EOS_TOKEN here too
84
  def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
85
  self.ws = ws
86
  self.snac = snac_decoder
87
  self.masker = audio_mask
88
  self.loop = loop
89
  self.device = target_device
90
+ self.eos_token_id = eos_token_id # Store constant EOS ID
91
  self.buf: list[int] = []
92
  self.tasks = set()
93
 
 
99
  Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
100
  """
101
  if len(block7) != 7:
102
+ # print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
103
+ return b"" # Less verbose logging
104
 
105
  try:
106
  # --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
 
120
  except IndexError:
121
  print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
122
  return b""
123
+ except Exception as e_map:
124
  print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
125
  return b""
126
 
 
140
  audio = self.snac.decode(codes)[0]
141
  except Exception as e_decode:
142
  print(f"Streamer Error: Exception during snac.decode: {e_decode}")
143
+ # Add more details if needed, e.g., shapes: {[c.shape for c in codes]}
 
 
 
144
  return b""
145
 
146
  # --- Post-processing ---
 
159
  try:
160
  await self.ws.send_bytes(data)
161
  except WebSocketDisconnect:
162
+ # This is expected if client disconnects first, don't log error
163
+ # print("Streamer: WebSocket disconnected during send.")
164
+ pass
165
  except Exception as e:
166
+ if "Cannot call \"send\" once a close message has been sent" in str(e) or \
167
+ "Connection is closed" in str(e):
168
  # This is expected if client disconnects during generation, suppress repetitive logs
169
  pass
170
  else:
 
177
  """
178
  if value.numel() == 0:
179
  return
 
180
  new_token_ids = value.squeeze().cpu().tolist()
181
  if isinstance(new_token_ids, int):
182
  new_token_ids = [new_token_ids]
 
187
  self.buf.clear()
188
  continue
189
 
190
+ # Use the constant EOS_TOKEN for comparison if needed (e.g. for logging)
191
  if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
192
  self.buf.append(t - AUDIO_BASE) # Store value relative to base
193
+ # else: # Optionally log ignored tokens
194
+ # if t != self.eos_token_id: # Don't warn about the EOS token itself
195
+ # print(f"Streamer Warning: Ignoring unexpected token {t}")
196
 
197
  if len(self.buf) == 7:
198
  audio_bytes = self._decode_block(self.buf)
199
  self.buf.clear()
200
 
201
  if audio_bytes:
 
202
  future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
203
  self.tasks.add(future)
204
  future.add_done_callback(self.tasks.discard)
205
 
 
206
  if self.masker.sent_blocks == 0:
207
  self.masker.sent_blocks = 1
208
 
 
218
 
219
  @app.on_event("startup")
220
  async def load_models_startup():
221
+ # Keep global references, but EOS_TOKEN is now a constant again
222
+ global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
223
 
224
  print(f"🚀 Starting up on device: {device}")
225
  print("⏳ Lade Modelle …", flush=True)
 
248
  print(f"Model loaded to {model.device} with dtype {model.dtype}.")
249
  model.eval()
250
 
251
+ # --- Print comparison for EOS token IDs but use the constant ---
252
  conf_eos = model.config.eos_token_id
253
  tok_eos = tok.eos_token_id
254
  print(f"Model Config EOS ID: {conf_eos}")
255
  print(f"Tokenizer EOS ID: {tok_eos}")
256
+ print(f"Using Constant EOS_TOKEN: {EOS_TOKEN}") # State the used constant
257
+ if conf_eos != EOS_TOKEN or tok_eos != EOS_TOKEN:
258
+ print(f"⚠️ WARNING: Constant EOS_TOKEN {EOS_TOKEN} differs from model/tokenizer IDs ({conf_eos}/{tok_eos}).")
259
+ # --- End EOS comparison ---
260
 
261
+ # Set pad_token_id if None (use the constant EOS)
 
 
 
 
 
 
 
 
 
262
  if model.config.pad_token_id is None:
263
+ print(f"Setting model.config.pad_token_id to Constant EOS token ID ({EOS_TOKEN})")
264
+ model.config.pad_token_id = EOS_TOKEN
 
265
 
266
  audio_ids_device = AUDIO_IDS_CPU.to(device)
267
+ # Pass the constant EOS_TOKEN to the mask
268
+ masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
269
  print("AudioMask initialized.")
270
 
271
+ # Pass the constant EOS_TOKEN to the stopping criteria
272
+ stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
273
  print("StoppingCriteria initialized.")
274
 
275
  print("✅ Modelle geladen und bereit!", flush=True)
 
296
  # 7) WebSocket‑Endpoint (vereinfacht mit Streamer) ---------------------
297
  @app.websocket("/ws/tts")
298
  async def tts(ws: WebSocket):
299
+ # No need for global actual_eos_token_id
300
  await ws.accept()
301
  print("🔌 Client connected")
302
  streamer = None
 
317
  print(f"Generating audio for: '{text}' with voice '{voice}'")
318
  ids, attn = build_prompt(text, voice)
319
  masker.reset()
320
+ # Pass the constant EOS_TOKEN to streamer
321
+ streamer = AudioStreamer(ws, snac, masker, main_loop, device, EOS_TOKEN)
322
 
323
  print("Starting generation in background thread...")
324
+ # Use sampling parameters with anti-repetition measures
325
  await asyncio.to_thread(
326
  model.generate,
327
  input_ids=ids,
328
  attention_mask=attn,
329
+ max_new_tokens=2500, # Or adjust as needed
330
  logits_processor=[masker],
331
  stopping_criteria=stopping_criteria,
332
+ # --- Sampling Parameters with Anti-Repetition ---
333
  do_sample=True,
334
+ temperature=0.6, # Adjust if needed
335
+ top_p=0.9, # Adjust if needed
336
+ repetition_penalty=1.2, # Increased (experiment!)
337
+ no_repeat_ngram_size=4, # Added (experiment!)
338
  # --- End Sampling Parameters ---
339
  use_cache=True,
340
  streamer=streamer,
341
+ eos_token_id=EOS_TOKEN # Explicitly pass constant EOS ID
342
  )
343
  print("Generation thread finished.")
344
 
 
371
  try:
372
  await ws.close(code=1000)
373
  except RuntimeError as e_close:
374
+ if "Cannot call \"send\"" not in str(e_close) and "Connection is closed" not in str(e_close):
 
375
  print(f"Runtime error closing websocket: {e_close}")
376
  except Exception as e_close_final:
377
  print(f"Error closing websocket: {e_close_final}")