Tomtom84 committed on
Commit 53012c3 · verified · 1 Parent(s): 55145d2

Update app.py

Files changed (1): app.py (+95 -146)
app.py CHANGED
@@ -36,6 +36,7 @@ NEW_BLOCK = 128257
 EOS_TOKEN = 128258
 AUDIO_BASE = 128266
 AUDIO_SPAN = 4096 * 7 # 28672 Codes
+CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
 # Create AUDIO_IDS on the correct device later in load_models
 AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
 
@@ -45,113 +46,117 @@ class AudioMask(LogitsProcessor):
         super().__init__()
         # Allow NEW_BLOCK and all valid audio tokens initially
         self.allow = torch.cat([
-            torch.tensor([new_block_token_id], device=audio_ids.device), # Add NEW_BLOCK token ID
+            torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long),
             audio_ids
         ], dim=0)
-        self.eos = torch.tensor([eos_token_id], device=audio_ids.device) # Store EOS token ID as tensor
-        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0) # Precompute combined tensor
+        self.eos = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
+        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
         self.sent_blocks = 0 # State: Number of audio blocks sent
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        # Determine which tokens are allowed based on whether blocks have been sent
         current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
-
-        # Create a mask initialized to negative infinity
         mask = torch.full_like(scores, float("-inf"))
-        # Set allowed token scores to 0 (effectively allowing them)
         mask[:, current_allow] = 0
-        # Apply the mask to the scores
         return scores + mask
 
     def reset(self):
-        """Resets the state for a new generation request."""
        self.sent_blocks = 0
 
 # 3) StoppingCriteria für EOS ---------------------------------------
-# generate() needs explicit stopping criteria when using a streamer
 class EosStoppingCriteria(StoppingCriteria):
     def __init__(self, eos_token_id: int):
         self.eos_token_id = eos_token_id
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        # Check if the *last* generated token is the EOS token
         if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
-            # print("StoppingCriteria: EOS detected.")
             return True
         return False
 
 # 4) Benutzerdefinierter AudioStreamer -------------------------------
 class AudioStreamer(BaseStreamer):
-    # --- Updated __init__ to accept target_device ---
     def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
         self.ws = ws
         self.snac = snac_decoder
-        self.masker = audio_mask # Reference to the mask to update sent_blocks
-        self.loop = loop # Event loop of the main thread for run_coroutine_threadsafe
-        # --- Use the passed target_device ---
+        self.masker = audio_mask
+        self.loop = loop
         self.device = target_device
-        self.buf: list[int] = [] # Buffer for audio token values (AUDIO_BASE subtracted)
-        self.tasks = set() # Keep track of pending send tasks
+        self.buf: list[int] = []
+        self.tasks = set()
 
     def _decode_block(self, block7: list[int]) -> bytes:
         """
         Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
-        NOTE: The mapping from the 7 tokens to the 3 SNAC codebooks (l1, l2, l3)
-        is based on the structure found in the previous while-loop version.
-        If audio is distorted, this mapping is the primary suspect.
-        Ensure this mapping is correct for the specific model!
+        NOTE: Extracts base code value (0-4095) using modulo, assuming
+        input values represent (slot_offset + code_value).
+        Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
         """
         if len(block7) != 7:
             print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
-            return b"" # Return empty bytes if block is incomplete
+            return b""
 
-        # --- Mapping derived from previous user version (indices [0], [1,4], [2,3,5,6]) ---
-        # This seems more likely to be correct for Kartoffel_Orpheus if the previous version worked.
         try:
-            l1 = [block7[0]] # Index 0
-            l2 = [block7[1], block7[4]] # Indices 1, 4
-            l3 = [block7[2], block7[3], block7[5], block7[6]] # Indices 2, 3, 5, 6
+            # --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
+            code_val_0 = block7[0] % CODEBOOK_SIZE
+            code_val_1 = block7[1] % CODEBOOK_SIZE
+            code_val_2 = block7[2] % CODEBOOK_SIZE
+            code_val_3 = block7[3] % CODEBOOK_SIZE
+            code_val_4 = block7[4] % CODEBOOK_SIZE
+            code_val_5 = block7[5] % CODEBOOK_SIZE
+            code_val_6 = block7[6] % CODEBOOK_SIZE
+
+            # --- Map the extracted code values to the SNAC codebooks (l1, l2, l3) ---
+            # Using the structure from the user's previous version, believed to be correct
+            l1 = [code_val_0]
+            l2 = [code_val_1, code_val_4]
+            l3 = [code_val_2, code_val_3, code_val_5, code_val_6]
+
         except IndexError:
             print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
             return b""
+        except Exception as e_map: # Catch potential issues with modulo/mapping
+            print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
+            return b""
 
-        # --- Alternative Hypothesis (commented out): Interleaving mapping ---
-        # try:
-        #     l1 = [block7[0], block7[3], block7[6]] # Codebook 1 indices: 0, 3, 6
-        #     l2 = [block7[1], block7[4]] # Codebook 2 indices: 1, 4
-        #     l3 = [block7[2], block7[5]] # Codebook 3 indices: 2, 5
-        # except IndexError:
-        #     print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
-        #     return b""
-        # --- End Alternative Hypothesis ---
-
-        # Convert lists to tensors on the correct device
-        # Use self.device which was set correctly in __init__
-        codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
-        codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
-        codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
-        codes = [codes_l1, codes_l2, codes_l3] # List of tensors for SNAC
-
-        # Decode using SNAC
-        with torch.no_grad():
-            # self.snac should already be on self.device from load_models_startup
-            audio = self.snac.decode(codes)[0] # Decode expects list of tensors, result might have batch dim
+        # --- Convert lists to tensors on the correct device ---
+        try:
+            codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
+            codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
+            codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
+            codes = [codes_l1, codes_l2, codes_l3]
+        except Exception as e_tensor:
+            print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
+            return b""
 
-        # Squeeze, move to CPU, convert to numpy
-        audio_np = audio.squeeze().detach().cpu().numpy()
+        # --- Decode using SNAC ---
+        try:
+            with torch.no_grad():
+                # self.snac should already be on self.device from load_models_startup
+                audio = self.snac.decode(codes)[0] # Decode expects list of tensors, result might have batch dim
+        except Exception as e_decode:
+            # Add more detailed logging here if it fails again
+            print(f"Streamer Error: Exception during snac.decode: {e_decode}")
+            print(f"Input codes shapes: {[c.shape for c in codes]}")
+            print(f"Input codes dtypes: {[c.dtype for c in codes]}")
+            print(f"Input codes devices: {[c.device for c in codes]}")
+            # Avoid printing potentially huge lists, maybe just check min/max?
+            print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
+            return b""
 
-        # Convert to 16-bit PCM bytes
-        audio_bytes = (audio_np * 32767).astype("int16").tobytes()
-        return audio_bytes
+        # --- Post-processing ---
+        try:
+            audio_np = audio.squeeze().detach().cpu().numpy()
+            audio_bytes = (audio_np * 32767).astype("int16").tobytes()
+            return audio_bytes
+        except Exception as e_post:
+            print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
+            return b""
 
     async def _send_audio_bytes(self, data: bytes):
         """Coroutine to send bytes over WebSocket."""
-        if not data: # Don't send empty bytes
+        if not data:
             return
         try:
             await self.ws.send_bytes(data)
-            # print(f"Streamer: Sent {len(data)} audio bytes.")
         except WebSocketDisconnect:
             print("Streamer: WebSocket disconnected during send.")
         except Exception as e:
@@ -159,67 +164,43 @@ class AudioStreamer(BaseStreamer):
 
     def put(self, value: torch.LongTensor):
         """
-        Receives new token IDs (Tensor) from generate() (runs in worker thread).
-        Processes tokens, decodes full blocks, and schedules sending via run_coroutine_threadsafe.
+        Receives new token IDs (Tensor) from generate().
+        Processes tokens, decodes full blocks, and schedules sending.
         """
-        # Ensure value is on CPU and flatten to a list of ints
         if value.numel() == 0:
             return
         new_token_ids = value.squeeze().tolist()
-        if isinstance(new_token_ids, int): # Handle single token case
+        if isinstance(new_token_ids, int):
             new_token_ids = [new_token_ids]
 
         for t in new_token_ids:
             if t == EOS_TOKEN:
-                # print("Streamer: EOS token encountered.")
-                # EOS is handled by StoppingCriteria, no action needed here except maybe logging.
-                break # Stop processing this batch if EOS is found
-
+                break
             if t == NEW_BLOCK:
-                # print("Streamer: NEW_BLOCK token encountered.")
-                # NEW_BLOCK indicates the start of audio, might reset buffer if needed
                 self.buf.clear()
-                continue # Move to the next token
-
-            # Check if token is within the expected audio range
+                continue
             if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
-                # Store value relative to base (IMPORTANT for _decode_block)
-                self.buf.append(t - AUDIO_BASE)
-            else:
-                # Log unexpected tokens (like START_TOKEN or others if generation goes wrong)
+                self.buf.append(t - AUDIO_BASE) # Store value relative to base
+            # else: # Optionally log ignored tokens
                 # print(f"Streamer Warning: Ignoring unexpected token {t}")
-                pass # Ignore tokens outside the audio range
 
-            # If buffer has 7 tokens, decode and send
             if len(self.buf) == 7:
                 audio_bytes = self._decode_block(self.buf)
-                self.buf.clear() # Clear buffer after processing
+                self.buf.clear()
 
-                if audio_bytes: # Only send if decoding was successful
-                    # Schedule the async send function to run on the main event loop
+                if audio_bytes:
                     future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                     self.tasks.add(future)
-                    # Optional: Remove completed tasks to prevent memory leak if generation is very long
                     future.add_done_callback(self.tasks.discard)
 
-                    # Allow EOS only after the first full block has been processed and scheduled for sending
                     if self.masker.sent_blocks == 0:
-                        # print("Streamer: First audio block processed, allowing EOS.")
-                        self.masker.sent_blocks = 1 # Update state in the mask
-
-        # Note: No need to explicitly wait for tasks here. put() should return quickly.
+                        self.masker.sent_blocks = 1
 
     def end(self):
         """Called by generate() when generation finishes."""
-        # Handle any remaining tokens in the buffer (optional, here we discard them)
         if len(self.buf) > 0:
             print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
             self.buf.clear()
-
-        # Optional: Wait briefly for any outstanding send tasks to complete?
-        # This is tricky because end() is sync. A robust solution might involve
-        # signaling the WebSocket handler to wait before closing.
-        # For simplicity, we rely on FastAPI/Uvicorn's graceful shutdown handling.
         # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")
         pass
 
@@ -227,7 +208,7 @@ class AudioStreamer(BaseStreamer):
 app = FastAPI()
 
 @app.on_event("startup")
-async def load_models_startup(): # Make startup async if needed for future async loads
+async def load_models_startup():
     global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
 
     print(f"🚀 Starting up on device: {device}")
@@ -236,41 +217,32 @@ async def load_models_startup(): # Make startup async if needed for future async
     tok = AutoTokenizer.from_pretrained(REPO)
     print("Tokenizer loaded.")
 
-    # Load SNAC first (usually smaller)
     snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
-    # --- FIXED Print statement ---
     print(f"SNAC loaded to {device}.") # Use the global device variable
 
-    # Load the main model
-    # Determine appropriate dtype based on device and support
-    model_dtype = torch.float32 # Default to float32 for CPU
+    model_dtype = torch.float32
     if device == "cuda":
         if torch.cuda.is_bf16_supported():
             model_dtype = torch.bfloat16
             print("Using bfloat16 for model.")
         else:
-            model_dtype = torch.float16 # Fallback to float16 if bfloat16 not supported
+            model_dtype = torch.float16
             print("Using float16 for model.")
 
     model = AutoModelForCausalLM.from_pretrained(
         REPO,
-        device_map={"": 0} if device == "cuda" else None, # Assign to GPU 0 if cuda
+        device_map={"": 0} if device == "cuda" else None,
         torch_dtype=model_dtype,
-        low_cpu_mem_usage=True, # Good practice for large models
+        low_cpu_mem_usage=True,
     )
-    model.config.pad_token_id = model.config.eos_token_id # Set pad token
+    model.config.pad_token_id = model.config.eos_token_id
     print(f"Model loaded to {model.device} with dtype {model.dtype}.")
-
-    # Ensure model is in evaluation mode
     model.eval()
 
-    # Initialize AudioMask (needs AUDIO_IDS on the correct device)
     audio_ids_device = AUDIO_IDS_CPU.to(device)
     masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
     print("AudioMask initialized.")
 
-    # Initialize StoppingCriteria
-    # IMPORTANT: Create the list and add the criteria instance
     stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
     print("StoppingCriteria initialized.")
 
@@ -283,18 +255,15 @@ def hello():
 # 6) Helper zum Prompt Bauen -------------------------------------------
 def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
     """Builds the input_ids and attention_mask for the model."""
-    # Format: <START> <VOICE>: <TEXT> <NEW_BLOCK>
     prompt_text = f"{voice}: {text}"
     prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
 
-    # Construct input_ids tensor
     input_ids = torch.cat([
-        torch.tensor([[START_TOKEN]], device=device, dtype=torch.long), # Start token
-        prompt_ids, # Encoded prompt
-        torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long) # New block token to trigger audio
+        torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),
+        prompt_ids,
+        torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)
     ], dim=1)
 
-    # Create attention mask (all ones)
     attention_mask = torch.ones_like(input_ids)
     return input_ids, attention_mask
 
@@ -303,49 +272,37 @@ def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
 async def tts(ws: WebSocket):
     await ws.accept()
     print("🔌 Client connected")
-    streamer = None # Initialize for finally block
-    main_loop = asyncio.get_running_loop() # Get the current event loop
+    streamer = None
+    main_loop = asyncio.get_running_loop()
 
     try:
-        # Receive configuration
         req_text = await ws.receive_text()
         print(f"Received request: {req_text}")
         req = json.loads(req_text)
-        text = req.get("text", "Hallo Welt, wie geht es dir heute?") # Default text
-        voice = req.get("voice", "Jakob") # Default voice
+        text = req.get("text", "Hallo Welt, wie geht es dir heute?")
+        voice = req.get("voice", "Jakob")
 
         if not text:
             print("⚠️ Request text is empty.")
-            await ws.close(code=1003, reason="Text cannot be empty") # 1003 = Cannot accept data type
+            await ws.close(code=1003, reason="Text cannot be empty")
             return
 
         print(f"Generating audio for: '{text}' with voice '{voice}'")
-
-        # Prepare prompt
         ids, attn = build_prompt(text, voice)
-
-        # --- Reset stateful components ---
-        masker.reset() # CRITICAL: Reset the mask state for the new request
-
-        # --- Create Streamer Instance ---
-        # --- Pass the global 'device' variable ---
+        masker.reset()
         streamer = AudioStreamer(ws, snac, masker, main_loop, device)
 
-        # --- Run model.generate in a separate thread ---
-        # This prevents blocking the main FastAPI event loop
         print("Starting generation in background thread...")
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
-            max_new_tokens=1500, # Limit generation length (adjust as needed)
+            max_new_tokens=1500,
            logits_processor=[masker],
            stopping_criteria=stopping_criteria,
-            do_sample=False, # Use greedy decoding for potentially more stable audio
-            # do_sample=True, temperature=0.7, top_p=0.95, # Or use sampling
+            do_sample=False, # Using greedy decoding
            use_cache=True,
-            streamer=streamer # Pass the custom streamer
-            # No need to manage past_key_values manually
+            streamer=streamer
        )
        print("Generation thread finished.")
 
@@ -358,32 +315,26 @@ async def tts(ws: WebSocket):
     except Exception as e:
         error_details = traceback.format_exc()
         print(f"❌ WS‑Error: {e}\n{error_details}", flush=True)
-        # Try to send an error message before closing, if possible
         error_payload = json.dumps({"error": str(e)})
         try:
             if ws.client_state.name == "CONNECTED":
-                await ws.send_text(error_payload) # Send error as text/json
+                await ws.send_text(error_payload)
         except Exception:
-            pass # Ignore error during error reporting
-        # Close with internal server error code
+            pass
         if ws.client_state.name == "CONNECTED":
-            await ws.close(code=1011) # 1011 = Internal Server Error
+            await ws.close(code=1011)
     finally:
-        # Ensure streamer's end method is called if it exists
         if streamer:
             try:
-                # print("Calling streamer.end()")
                 streamer.end()
             except Exception as e_end:
                 print(f"Error during streamer.end(): {e_end}")
 
-        # Ensure WebSocket is closed
         print("Closing connection.")
         if ws.client_state.name == "CONNECTED":
             try:
-                await ws.close(code=1000) # 1000 = Normal Closure
+                await ws.close(code=1000)
             except RuntimeError as e_close:
-                # Can happen if connection is already closing/closed
                 print(f"Runtime error closing websocket: {e_close}")
             except Exception as e_close_final:
                 print(f"Error closing websocket: {e_close_final}")
@@ -395,6 +346,4 @@ async def tts(ws: WebSocket):
 if __name__ == "__main__":
     import uvicorn
     print("Starting Uvicorn server...")
-    # Use reload=True only for development, remove for production
-    # Consider adding --workers 1 if you experience issues with multiple workers and global state/GPU memory
-    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info") #, reload=True)
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
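For reference, a minimal standalone sketch of the block decode this commit settles on: subtract AUDIO_BASE, strip the per-slot offset with modulo CODEBOOK_SIZE, and distribute the seven values across the three SNAC codebooks using the [0], [1,4], [2,3,5,6] layout from the diff. The helper name split_block and the example block are illustrative, not part of app.py:

# split_block.py — illustrative sketch, not part of this commit
AUDIO_BASE = 128266
CODEBOOK_SIZE = 4096

def split_block(token_ids: list[int]) -> tuple[list[int], list[int], list[int]]:
    """Map one 7-token audio block to the three SNAC codebook lists."""
    assert len(token_ids) == 7, "SNAC blocks are exactly 7 tokens"
    # Make values relative to the audio vocabulary, then strip the slot offset.
    vals = [(t - AUDIO_BASE) % CODEBOOK_SIZE for t in token_ids]
    l1 = [vals[0]]                             # coarse codebook: slot 0
    l2 = [vals[1], vals[4]]                    # mid codebook:    slots 1, 4
    l3 = [vals[2], vals[3], vals[5], vals[6]]  # fine codebook:   slots 2, 3, 5, 6
    return l1, l2, l3

# Example: each slot carries (slot_index * CODEBOOK_SIZE + code) above AUDIO_BASE.
block = [AUDIO_BASE + i * CODEBOOK_SIZE + 100 + i for i in range(7)]
print(split_block(block))  # ([100], [101, 104], [102, 103, 105, 106])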
 
36
  EOS_TOKEN = 128258
37
  AUDIO_BASE = 128266
38
  AUDIO_SPAN = 4096 * 7 # 28672 Codes
39
+ CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
40
  # Create AUDIO_IDS on the correct device later in load_models
41
  AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
42
 
 
46
  super().__init__()
47
  # Allow NEW_BLOCK and all valid audio tokens initially
48
  self.allow = torch.cat([
49
+ torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long),
50
  audio_ids
51
  ], dim=0)
52
+ self.eos = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
53
+ self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
54
  self.sent_blocks = 0 # State: Number of audio blocks sent
55
 
56
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 
57
  current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
 
 
58
  mask = torch.full_like(scores, float("-inf"))
 
59
  mask[:, current_allow] = 0
 
60
  return scores + mask
61
 
62
  def reset(self):
 
63
  self.sent_blocks = 0
64
 
65
  # 3) StoppingCriteria für EOS ---------------------------------------
 
66
  class EosStoppingCriteria(StoppingCriteria):
67
  def __init__(self, eos_token_id: int):
68
  self.eos_token_id = eos_token_id
69
 
70
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
71
  if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
 
72
  return True
73
  return False
74
 
75
  # 4) Benutzerdefinierter AudioStreamer -------------------------------
76
  class AudioStreamer(BaseStreamer):
 
77
  def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
78
  self.ws = ws
79
  self.snac = snac_decoder
80
+ self.masker = audio_mask
81
+ self.loop = loop
 
82
  self.device = target_device
83
+ self.buf: list[int] = []
84
+ self.tasks = set()
85
 
86
  def _decode_block(self, block7: list[int]) -> bytes:
87
  """
88
  Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
89
+ NOTE: Extracts base code value (0-4095) using modulo, assuming
90
+ input values represent (slot_offset + code_value).
91
+ Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
 
92
  """
93
  if len(block7) != 7:
94
  print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
95
+ return b""
96
 
 
 
97
  try:
98
+ # --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
99
+ code_val_0 = block7[0] % CODEBOOK_SIZE
100
+ code_val_1 = block7[1] % CODEBOOK_SIZE
101
+ code_val_2 = block7[2] % CODEBOOK_SIZE
102
+ code_val_3 = block7[3] % CODEBOOK_SIZE
103
+ code_val_4 = block7[4] % CODEBOOK_SIZE
104
+ code_val_5 = block7[5] % CODEBOOK_SIZE
105
+ code_val_6 = block7[6] % CODEBOOK_SIZE
106
+
107
+ # --- Map the extracted code values to the SNAC codebooks (l1, l2, l3) ---
108
+ # Using the structure from the user's previous version, believed to be correct
109
+ l1 = [code_val_0]
110
+ l2 = [code_val_1, code_val_4]
111
+ l3 = [code_val_2, code_val_3, code_val_5, code_val_6]
112
+
113
  except IndexError:
114
  print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
115
  return b""
116
+ except Exception as e_map: # Catch potential issues with modulo/mapping
117
+ print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
118
+ return b""
119
 
120
+ # --- Convert lists to tensors on the correct device ---
121
+ try:
122
+ codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
123
+ codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
124
+ codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
125
+ codes = [codes_l1, codes_l2, codes_l3]
126
+ except Exception as e_tensor:
127
+ print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
128
+ return b""
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ # --- Decode using SNAC ---
131
+ try:
132
+ with torch.no_grad():
133
+ # self.snac should already be on self.device from load_models_startup
134
+ audio = self.snac.decode(codes)[0] # Decode expects list of tensors, result might have batch dim
135
+ except Exception as e_decode:
136
+ # Add more detailed logging here if it fails again
137
+ print(f"Streamer Error: Exception during snac.decode: {e_decode}")
138
+ print(f"Input codes shapes: {[c.shape for c in codes]}")
139
+ print(f"Input codes dtypes: {[c.dtype for c in codes]}")
140
+ print(f"Input codes devices: {[c.device for c in codes]}")
141
+ # Avoid printing potentially huge lists, maybe just check min/max?
142
+ print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
143
+ return b""
144
 
145
+ # --- Post-processing ---
146
+ try:
147
+ audio_np = audio.squeeze().detach().cpu().numpy()
148
+ audio_bytes = (audio_np * 32767).astype("int16").tobytes()
149
+ return audio_bytes
150
+ except Exception as e_post:
151
+ print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
152
+ return b""
153
 
154
  async def _send_audio_bytes(self, data: bytes):
155
  """Coroutine to send bytes over WebSocket."""
156
+ if not data:
157
  return
158
  try:
159
  await self.ws.send_bytes(data)
 
160
  except WebSocketDisconnect:
161
  print("Streamer: WebSocket disconnected during send.")
162
  except Exception as e:
 
164
 
165
  def put(self, value: torch.LongTensor):
166
  """
167
+ Receives new token IDs (Tensor) from generate().
168
+ Processes tokens, decodes full blocks, and schedules sending.
169
  """
 
170
  if value.numel() == 0:
171
  return
172
  new_token_ids = value.squeeze().tolist()
173
+ if isinstance(new_token_ids, int):
174
  new_token_ids = [new_token_ids]
175
 
176
  for t in new_token_ids:
177
  if t == EOS_TOKEN:
178
+ break
 
 
 
179
  if t == NEW_BLOCK:
 
 
180
  self.buf.clear()
181
+ continue
 
 
182
  if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
183
+ self.buf.append(t - AUDIO_BASE) # Store value relative to base
184
+ # else: # Optionally log ignored tokens
 
 
185
  # print(f"Streamer Warning: Ignoring unexpected token {t}")
 
186
 
 
187
  if len(self.buf) == 7:
188
  audio_bytes = self._decode_block(self.buf)
189
+ self.buf.clear()
190
 
191
+ if audio_bytes:
 
192
  future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
193
  self.tasks.add(future)
 
194
  future.add_done_callback(self.tasks.discard)
195
 
 
196
  if self.masker.sent_blocks == 0:
197
+ self.masker.sent_blocks = 1
 
 
 
198
 
199
  def end(self):
200
  """Called by generate() when generation finishes."""
 
201
  if len(self.buf) > 0:
202
  print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
203
  self.buf.clear()
 
 
 
 
 
204
  # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")
205
  pass
206
 
 
208
  app = FastAPI()
209
 
210
  @app.on_event("startup")
211
+ async def load_models_startup():
212
  global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
213
 
214
  print(f"🚀 Starting up on device: {device}")
 
217
  tok = AutoTokenizer.from_pretrained(REPO)
218
  print("Tokenizer loaded.")
219
 
 
220
  snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
 
221
  print(f"SNAC loaded to {device}.") # Use the global device variable
222
 
223
+ model_dtype = torch.float32
 
 
224
  if device == "cuda":
225
  if torch.cuda.is_bf16_supported():
226
  model_dtype = torch.bfloat16
227
  print("Using bfloat16 for model.")
228
  else:
229
+ model_dtype = torch.float16
230
  print("Using float16 for model.")
231
 
232
  model = AutoModelForCausalLM.from_pretrained(
233
  REPO,
234
+ device_map={"": 0} if device == "cuda" else None,
235
  torch_dtype=model_dtype,
236
+ low_cpu_mem_usage=True,
237
  )
238
+ model.config.pad_token_id = model.config.eos_token_id
239
  print(f"Model loaded to {model.device} with dtype {model.dtype}.")
 
 
240
  model.eval()
241
 
 
242
  audio_ids_device = AUDIO_IDS_CPU.to(device)
243
  masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
244
  print("AudioMask initialized.")
245
 
 
 
246
  stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
247
  print("StoppingCriteria initialized.")
248
 
 
255
  # 6) Helper zum Prompt Bauen -------------------------------------------
256
  def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
257
  """Builds the input_ids and attention_mask for the model."""
 
258
  prompt_text = f"{voice}: {text}"
259
  prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
260
 
 
261
  input_ids = torch.cat([
262
+ torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),
263
+ prompt_ids,
264
+ torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)
265
  ], dim=1)
266
 
 
267
  attention_mask = torch.ones_like(input_ids)
268
  return input_ids, attention_mask
269
 
 
272
  async def tts(ws: WebSocket):
273
  await ws.accept()
274
  print("🔌 Client connected")
275
+ streamer = None
276
+ main_loop = asyncio.get_running_loop()
277
 
278
  try:
 
279
  req_text = await ws.receive_text()
280
  print(f"Received request: {req_text}")
281
  req = json.loads(req_text)
282
+ text = req.get("text", "Hallo Welt, wie geht es dir heute?")
283
+ voice = req.get("voice", "Jakob")
284
 
285
  if not text:
286
  print("⚠️ Request text is empty.")
287
+ await ws.close(code=1003, reason="Text cannot be empty")
288
  return
289
 
290
  print(f"Generating audio for: '{text}' with voice '{voice}'")
 
 
291
  ids, attn = build_prompt(text, voice)
292
+ masker.reset()
 
 
 
 
 
293
  streamer = AudioStreamer(ws, snac, masker, main_loop, device)
294
 
 
 
295
  print("Starting generation in background thread...")
296
  await asyncio.to_thread(
297
  model.generate,
298
  input_ids=ids,
299
  attention_mask=attn,
300
+ max_new_tokens=1500,
301
  logits_processor=[masker],
302
  stopping_criteria=stopping_criteria,
303
+ do_sample=False, # Using greedy decoding
 
304
  use_cache=True,
305
+ streamer=streamer
 
306
  )
307
  print("Generation thread finished.")
308
 
 
315
  except Exception as e:
316
  error_details = traceback.format_exc()
317
  print(f"❌ WS‑Error: {e}\n{error_details}", flush=True)
 
318
  error_payload = json.dumps({"error": str(e)})
319
  try:
320
  if ws.client_state.name == "CONNECTED":
321
+ await ws.send_text(error_payload)
322
  except Exception:
323
+ pass
 
324
  if ws.client_state.name == "CONNECTED":
325
+ await ws.close(code=1011)
326
  finally:
 
327
  if streamer:
328
  try:
 
329
  streamer.end()
330
  except Exception as e_end:
331
  print(f"Error during streamer.end(): {e_end}")
332
 
 
333
  print("Closing connection.")
334
  if ws.client_state.name == "CONNECTED":
335
  try:
336
+ await ws.close(code=1000)
337
  except RuntimeError as e_close:
 
338
  print(f"Runtime error closing websocket: {e_close}")
339
  except Exception as e_close_final:
340
  print(f"Error closing websocket: {e_close_final}")
 
346
  if __name__ == "__main__":
347
  import uvicorn
348
  print("Starting Uvicorn server...")
349
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
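A hypothetical client sketch for exercising the endpoint. The websocket route decorator sits outside this diff, so the "/tts" path below is an assumption; the 16-bit mono PCM at 24 kHz follows from the int16 conversion in _decode_block and the snac_24khz decoder:

# client.py — illustrative sketch; the "/tts" route path is an assumption
import asyncio
import json
import wave

import websockets  # pip install websockets

async def main():
    async with websockets.connect("ws://localhost:7860/tts") as ws:
        # The server expects one JSON config message with "text" and "voice".
        await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
        with wave.open("out.wav", "wb") as f:
            f.setnchannels(1)      # mono
            f.setsampwidth(2)      # 16-bit PCM frames, as produced by the streamer
            f.setframerate(24000)  # snac_24khz output rate
            try:
                while True:
                    msg = await ws.recv()
                    if isinstance(msg, bytes):
                        f.writeframes(msg)     # one decoded 7-token block
                    else:
                        print("server:", msg)  # e.g. a JSON error payload
            except websockets.ConnectionClosed:
                pass  # server closes with code 1000 when generation finishes

asyncio.run(main())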