Tomtom84 committed (verified)
Commit 96dc59a · Parent(s): 95e254f

Update app.py

Files changed (1): app.py (+90 -48)
app.py CHANGED
@@ -18,6 +18,7 @@ model = None
 snac = None
 masker = None
 stopping_criteria = None
+actual_eos_token_id = None # Will be determined during startup
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # 0) Login + Device ---------------------------------------------------
@@ -33,7 +34,7 @@ REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
 # CHUNK_TOKENS = 50 # Not directly used by us with the streamer approach
 START_TOKEN = 128259
 NEW_BLOCK = 128257
-EOS_TOKEN = 128258
+# EOS_TOKEN = 128258 # REMOVED - Will be determined from model/tokenizer config
 AUDIO_BASE = 128266
 AUDIO_SPAN = 4096 * 7 # 28672 Codes
 CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
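A note on these constants (not part of the commit): AUDIO_SPAN covers seven codebook slots of 4096 codes each. Assuming the usual Orpheus block layout, in which the i-th token of a 7-token block is offset by i * CODEBOOK_SIZE (this is what the `% CODEBOOK_SIZE` in `_decode_block` strips), an absolute token ID splits as sketched below; `split_audio_token` is a hypothetical helper, not a name in app.py.

```python
# Sketch only: map an absolute audio token ID to a (slot, code) pair.
AUDIO_BASE = 128266
CODEBOOK_SIZE = 4096

def split_audio_token(token_id: int) -> tuple[int, int]:
    rel = token_id - AUDIO_BASE       # position within the audio span
    slot = rel // CODEBOOK_SIZE       # which of the 7 positions in a block (0..6)
    code = rel % CODEBOOK_SIZE        # code value within that slot's codebook
    return slot, code

assert split_audio_token(AUDIO_BASE) == (0, 0)
assert split_audio_token(AUDIO_BASE + 4096 * 7 - 1) == (6, 4095)
```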
@@ -41,45 +42,61 @@ CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
 AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
 
 # 2) Logit‑Mask -------------------------------------------------------
+# Uses the dynamically determined EOS token ID
 class AudioMask(LogitsProcessor):
     def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
         super().__init__()
+        # Ensure input tensors are Long type for concatenation if needed, although indices are usually int
+        new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
+        eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
+
         # Allow NEW_BLOCK and all valid audio tokens initially
-        self.allow = torch.cat([
-            torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long),
-            audio_ids
-        ], dim=0)
-        self.eos = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
-        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
+        self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
+        self.eos = eos_tensor # Store EOS token ID as tensor
+        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0) # Precompute combined tensor
         self.sent_blocks = 0 # State: Number of audio blocks sent
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # Determine which tokens are allowed based on whether blocks have been sent
         current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
+
+        # Create a mask initialized to negative infinity
         mask = torch.full_like(scores, float("-inf"))
+        # Set allowed token scores to 0 (effectively allowing them)
         mask[:, current_allow] = 0
+        # Apply the mask to the scores
         return scores + mask
 
     def reset(self):
+        """Resets the state for a new generation request."""
        self.sent_blocks = 0
 
 # 3) StoppingCriteria für EOS ---------------------------------------
+# Uses the dynamically determined EOS token ID
 class EosStoppingCriteria(StoppingCriteria):
     def __init__(self, eos_token_id: int):
         self.eos_token_id = eos_token_id
+        if self.eos_token_id is None:
+            print("⚠️ EosStoppingCriteria initialized with eos_token_id=None!")
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        if self.eos_token_id is None:
+            return False # Cannot stop if EOS ID is unknown
+        # Check if the *last* generated token is the EOS token
         if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
+            # print("StoppingCriteria: EOS detected.")
             return True
         return False
 
 # 4) Benutzerdefinierter AudioStreamer -------------------------------
 class AudioStreamer(BaseStreamer):
-    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
+    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
         self.ws = ws
         self.snac = snac_decoder
         self.masker = audio_mask
         self.loop = loop
         self.device = target_device
+        self.eos_token_id = eos_token_id # Store EOS ID for potential use in put (optional)
         self.buf: list[int] = []
         self.tasks = set()
 
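For intuition (not part of the commit), here is the masking arithmetic from AudioMask.__call__ on a toy vocabulary: while sent_blocks == 0, EOS is excluded from the allowed set, so it stays at -inf and generation cannot stop before the first audio block exists.

```python
# Minimal sketch of the AudioMask logic with a toy vocabulary of 10 tokens.
import torch

allow = torch.tensor([2, 5, 6])            # stand-ins for NEW_BLOCK + audio IDs
scores = torch.randn(1, 10)                # one row of next-token logits

mask = torch.full_like(scores, float("-inf"))
mask[:, allow] = 0
masked = scores + mask                     # -inf everywhere except allowed IDs

assert torch.isinf(masked[0, 0]) and not torch.isinf(masked[0, 5])
```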
@@ -105,7 +122,6 @@ class AudioStreamer(BaseStreamer):
         code_val_6 = block7[6] % CODEBOOK_SIZE
 
         # --- Map the extracted code values to the SNAC codebooks (l1, l2, l3) ---
-        # Using the structure from the user's previous version, believed to be correct
         l1 = [code_val_0]
         l2 = [code_val_1, code_val_4]
         l3 = [code_val_2, code_val_3, code_val_5, code_val_6]
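For reference (not part of the commit), the de-interleave above as a standalone function: per 7-token block, SNAC gets one coarse code (l1), two mid codes (l2), and four fine codes (l3).

```python
# Sketch only: split a 7-value block into SNAC's three codebook levels.
CODEBOOK_SIZE = 4096

def split_block(block7: list[int]) -> tuple[list[int], list[int], list[int]]:
    vals = [v % CODEBOOK_SIZE for v in block7]    # strip per-slot offsets
    l1 = [vals[0]]                                # coarsest level, 1 code/frame
    l2 = [vals[1], vals[4]]                       # middle level, 2 codes/frame
    l3 = [vals[2], vals[3], vals[5], vals[6]]     # finest level, 4 codes/frame
    return l1, l2, l3
```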
@@ -130,15 +146,12 @@ class AudioStreamer(BaseStreamer):
         # --- Decode using SNAC ---
         try:
             with torch.no_grad():
-                # self.snac should already be on self.device from load_models_startup
-                audio = self.snac.decode(codes)[0] # Decode expects list of tensors, result might have batch dim
+                audio = self.snac.decode(codes)[0]
         except Exception as e_decode:
-            # Add more detailed logging here if it fails again
             print(f"Streamer Error: Exception during snac.decode: {e_decode}")
             print(f"Input codes shapes: {[c.shape for c in codes]}")
             print(f"Input codes dtypes: {[c.dtype for c in codes]}")
             print(f"Input codes devices: {[c.device for c in codes]}")
-            # Avoid printing potentially huge lists, maybe just check min/max?
             print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
             return b""
 
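The float-to-bytes conversion itself lies outside these hunks. As a hedged sketch of the common approach for SNAC's 24 kHz float output (assuming `audio` is a waveform in [-1, 1]; `to_pcm16` is a hypothetical helper, not necessarily what app.py does):

```python
# Sketch only: convert a float waveform tensor to 16-bit little-endian PCM bytes.
import torch

def to_pcm16(audio: torch.Tensor) -> bytes:
    audio = audio.squeeze().clamp(-1.0, 1.0)        # drop batch/channel dims, limit range
    pcm = (audio * 32767.0).to(torch.int16).cpu()   # scale into int16 range
    return pcm.numpy().tobytes()
```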
@@ -160,7 +173,12 @@ class AudioStreamer(BaseStreamer):
         except WebSocketDisconnect:
             print("Streamer: WebSocket disconnected during send.")
         except Exception as e:
-            print(f"Streamer: Error sending bytes: {e}")
+            # Handle cases where sending fails after connection closed
+            if "Cannot call \"send\" once a close message has been sent" in str(e):
+                # This is expected if client disconnects during generation, suppress repetitive logs
+                pass
+            else:
+                print(f"Streamer: Error sending bytes: {e}")
 
     def put(self, value: torch.LongTensor):
         """
@@ -169,30 +187,34 @@ class AudioStreamer(BaseStreamer):
         """
         if value.numel() == 0:
             return
-        new_token_ids = value.squeeze().tolist()
+        # Ensure value is on CPU and flatten to a list of ints
+        new_token_ids = value.squeeze().cpu().tolist()
         if isinstance(new_token_ids, int):
             new_token_ids = [new_token_ids]
 
         for t in new_token_ids:
-            if t == EOS_TOKEN:
-                break
+            # No need to check for EOS here, StoppingCriteria handles it
             if t == NEW_BLOCK:
                 self.buf.clear()
                 continue
+
             if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                 self.buf.append(t - AUDIO_BASE) # Store value relative to base
-            # else: # Optionally log ignored tokens
-            #     print(f"Streamer Warning: Ignoring unexpected token {t}")
+            # else: # Optionally log ignored tokens outside audio range
+            #     if t != self.eos_token_id: # Don't warn about the EOS token itself
+            #         print(f"Streamer Warning: Ignoring unexpected token {t}")
 
             if len(self.buf) == 7:
                 audio_bytes = self._decode_block(self.buf)
                 self.buf.clear()
 
                 if audio_bytes:
+                    # Schedule the async send function to run on the main event loop
                     future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                     self.tasks.add(future)
                     future.add_done_callback(self.tasks.discard)
 
+                # Allow EOS only after the first full block has been processed
                 if self.masker.sent_blocks == 0:
                     self.masker.sent_blocks = 1
 
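Background on the run_coroutine_threadsafe call (not part of the commit): put() runs on the worker thread that executes model.generate via asyncio.to_thread, so it cannot await ws.send_bytes directly; it must hand the coroutine to the main event loop. A self-contained sketch of that bridge:

```python
# Sketch only: scheduling a coroutine on the main loop from a worker thread.
import asyncio

async def main() -> None:
    loop = asyncio.get_running_loop()

    async def send(data: bytes) -> None:      # stands in for ws.send_bytes
        print(f"sent {len(data)} bytes")

    def worker() -> None:                     # stands in for the generate thread
        fut = asyncio.run_coroutine_threadsafe(send(b"\x00" * 4096), loop)
        fut.result(timeout=5)                 # block the worker until sent

    await asyncio.to_thread(worker)

asyncio.run(main())
```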
@@ -201,7 +223,6 @@ class AudioStreamer(BaseStreamer):
         if len(self.buf) > 0:
             print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
             self.buf.clear()
-        # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")
         pass
 
 # 5) FastAPI App ------------------------------------------------------
@@ -209,7 +230,7 @@ app = FastAPI()
 
 @app.on_event("startup")
 async def load_models_startup():
-    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
+    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU, actual_eos_token_id
 
     print(f"🚀 Starting up on device: {device}")
     print("⏳ Lade Modelle …", flush=True)
@@ -218,7 +239,7 @@ async def load_models_startup():
     print("Tokenizer loaded.")
 
     snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
-    print(f"SNAC loaded to {device}.") # Use the global device variable
+    print(f"SNAC loaded to {device}.")
 
     model_dtype = torch.float32
     if device == "cuda":
@@ -235,25 +256,40 @@ async def load_models_startup():
         torch_dtype=model_dtype,
         low_cpu_mem_usage=True,
     )
-    model.config.pad_token_id = model.config.eos_token_id
     print(f"Model loaded to {model.device} with dtype {model.dtype}.")
     model.eval()
 
+    # --- Determine and set the correct EOS token ID ---
+    conf_eos = model.config.eos_token_id
+    tok_eos = tok.eos_token_id
+    print(f"Model Config EOS ID: {conf_eos}")
+    print(f"Tokenizer EOS ID: {tok_eos}")
+
+    if conf_eos is not None:
+        actual_eos_token_id = conf_eos
+    elif tok_eos is not None:
+        actual_eos_token_id = tok_eos
+        print(f"⚠️ Model config EOS ID is None, using Tokenizer EOS ID: {actual_eos_token_id}")
+    else:
+        raise ValueError("Could not determine EOS token ID from model config or tokenizer.")
+
+    print(f"Using EOS Token ID: {actual_eos_token_id}")
+    # Set pad_token_id to eos_token_id if not already set (common practice for generation)
+    if model.config.pad_token_id is None:
+        print(f"Setting model.config.pad_token_id to EOS token ID ({actual_eos_token_id})")
+        model.config.pad_token_id = actual_eos_token_id
+    # --- End EOS Token ID determination ---
+
     audio_ids_device = AUDIO_IDS_CPU.to(device)
-    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
+    # Pass the determined EOS ID to the mask
+    masker = AudioMask(audio_ids_device, NEW_BLOCK, actual_eos_token_id)
     print("AudioMask initialized.")
 
-    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
+    # Pass the determined EOS ID to the stopping criteria
+    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(actual_eos_token_id)])
     print("StoppingCriteria initialized.")
 
     print("✅ Modelle geladen und bereit!", flush=True)
-    print(f"Tokenizer EOS ID: {tok.eos_token_id}")
-    print(f"Model Config EOS ID: {model.config.eos_token_id}")
-    print(f"Constant EOS_TOKEN: {EOS_TOKEN}")
-    if tok.eos_token_id != EOS_TOKEN or model.config.eos_token_id != EOS_TOKEN:
-        print("⚠️ WARNING: EOS_TOKEN constant might not match model/tokenizer configuration!")
-        # Consider updating EOS_TOKEN if they differ, e.g.:
-        # EOS_TOKEN = model.config.eos_token_id
 
 @app.get("/")
 def hello():
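The EOS fallback chain introduced above, restated as a standalone helper (illustrative only; `resolve_eos_id` is not a name in app.py):

```python
# Sketch only: model config wins, tokenizer is the fallback, otherwise fail loudly.
from typing import Optional

def resolve_eos_id(config_eos: Optional[int], tokenizer_eos: Optional[int]) -> int:
    if config_eos is not None:
        return config_eos
    if tokenizer_eos is not None:
        return tokenizer_eos
    raise ValueError("No EOS token ID available from model config or tokenizer.")

assert resolve_eos_id(128258, None) == 128258
assert resolve_eos_id(None, 128258) == 128258
```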
@@ -277,6 +313,7 @@ def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
 # 7) WebSocket‑Endpoint (vereinfacht mit Streamer) ---------------------
 @app.websocket("/ws/tts")
 async def tts(ws: WebSocket):
+    global actual_eos_token_id # Ensure we can access the determined EOS ID
     await ws.accept()
     print("🔌 Client connected")
     streamer = None
@@ -297,24 +334,27 @@ async def tts(ws: WebSocket):
             print(f"Generating audio for: '{text}' with voice '{voice}'")
             ids, attn = build_prompt(text, voice)
             masker.reset()
-            streamer = AudioStreamer(ws, snac, masker, main_loop, device)
+            # Pass the determined EOS ID to the streamer as well (optional, for logging/checks)
+            streamer = AudioStreamer(ws, snac, masker, main_loop, device, actual_eos_token_id)
 
             print("Starting generation in background thread...")
+            # Use sampling parameters to avoid repetition
             await asyncio.to_thread(
-                model.generate,
-                input_ids=ids,
-                attention_mask=attn,
-                max_new_tokens=2500, # Keep or increase later if needed
-                logits_processor=[masker],
-                stopping_criteria=stopping_criteria,
-                # --- Changes ---
-                do_sample=True, # Enable sampling
-                temperature=0.6, # Introduce some randomness (adjust as needed)
-                top_p=0.9, # Focus sampling on more likely tokens (adjust as needed)
-                repetition_penalty=1.15, # Penalize recently generated tokens (adjust > 1.0)
-                # --- End Changes ---
-                use_cache=True,
-                streamer=streamer
+                model.generate,
+                input_ids=ids,
+                attention_mask=attn,
+                max_new_tokens=2500, # Increased slightly, adjust as needed
+                logits_processor=[masker],
+                stopping_criteria=stopping_criteria,
+                # --- Sampling Parameters ---
+                do_sample=True,
+                temperature=0.6,
+                top_p=0.9,
+                repetition_penalty=1.15,
+                # --- End Sampling Parameters ---
+                use_cache=True,
+                streamer=streamer,
+                eos_token_id=actual_eos_token_id # Explicitly pass correct EOS ID here too
             )
             print("Generation thread finished.")
 
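A minimal client sketch for exercising the endpoint (not part of the commit). The request schema is not visible in these hunks, so the JSON payload with `text` and `voice` fields, the voice name, and the port are assumptions:

```python
# Sketch only: stream PCM from /ws/tts and count the received bytes.
import asyncio
import json

import websockets  # pip install websockets

async def main() -> None:
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:  # port assumed
        await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))  # fields assumed
        pcm = bytearray()
        try:
            async for message in ws:
                if isinstance(message, bytes):
                    pcm.extend(message)        # one decoded SNAC block per message
        except websockets.ConnectionClosed:
            pass
        print(f"received {len(pcm)} bytes of PCM audio")

asyncio.run(main())
```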
@@ -347,7 +387,9 @@ async def tts(ws: WebSocket):
         try:
             await ws.close(code=1000)
         except RuntimeError as e_close:
-            print(f"Runtime error closing websocket: {e_close}")
+            # Suppress "Cannot call 'send'..." error during final close if already disconnected
+            if "Cannot call \"send\"" not in str(e_close):
+                print(f"Runtime error closing websocket: {e_close}")
         except Exception as e_close_final:
             print(f"Error closing websocket: {e_close_final}")
         elif ws.client_state.name != "DISCONNECTED":