Tomtom84 committed
Commit 641d199 · verified · 1 Parent(s): dbb4a9f

Update app.py

Files changed (1):
  app.py +63 -239
app.py CHANGED
@@ -24,37 +24,43 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN:
     print("🔑 Logging in to Hugging Face Hub...")
+    # Consider adding error handling for login failure if necessary
     login(HF_TOKEN)
 
 # torch.backends.cuda.enable_flash_sdp(False)  # Uncomment if needed for PyTorch 2.2 bug
 
 # 1) Constants ---------------------------------------------------------
 REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
-# CHUNK_TOKENS = 50  # Not directly used by us with the streamer approach
 START_TOKEN = 128259
-NEW_BLOCK = 128257
-EOS_TOKEN = 128258
-AUDIO_BASE = 128266
-AUDIO_SPAN = 4096 * 7  # 28672 codes
+NEW_BLOCK = 128257    # Token indicating start of audio generation
+EOS_TOKEN = 128258    # End-of-speech token
+AUDIO_BASE = 128266   # Base ID for audio tokens
+AUDIO_SPAN = 4096 * 7  # 7 codebooks * 4096 codes per book = 28672 possible audio tokens
 # Create AUDIO_IDS on the correct device later in load_models
 AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
 
 # 2) Logit mask --------------------------------------------------------
 class AudioMask(LogitsProcessor):
+    """
+    Manages allowed tokens during generation.
+    - Initially allows NEW_BLOCK and audio tokens.
+    - Allows EOS_TOKEN only after at least one audio block has been sent.
+    """
     def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
         super().__init__()
         # Allow NEW_BLOCK and all valid audio tokens initially
-        self.allow = torch.cat([
-            torch.tensor([new_block_token_id], device=audio_ids.device),  # Add NEW_BLOCK token ID
+        self.allow_initial = torch.cat([
+            torch.tensor([new_block_token_id], device=audio_ids.device),
             audio_ids
         ], dim=0)
-        self.eos = torch.tensor([eos_token_id], device=audio_ids.device)  # Store EOS token ID as tensor
-        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)  # Precompute combined tensor
+        self.eos = torch.tensor([eos_token_id], device=audio_ids.device)
+        # Precompute combined tensor for allowing audio, NEW_BLOCK, and EOS
+        self.allow_with_eos = torch.cat([self.allow_initial, self.eos], dim=0)
         self.sent_blocks = 0  # State: number of audio blocks sent
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # Determine which tokens are allowed based on whether blocks have been sent
-        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
+        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow_initial
 
         # Create a mask initialized to negative infinity
         mask = torch.full_like(scores, float("-inf"))
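The AudioMask above is a whitelist over the vocabulary: at each step every logit outside `current_allow` is pushed to -inf, so the model can only emit NEW_BLOCK, audio tokens, and (once a block has been streamed) EOS. A minimal standalone sketch of the same whitelisting pattern, with illustrative names that are not part of this commit:

    import torch
    from transformers import LogitsProcessor

    class AllowListMask(LogitsProcessor):
        """Toy version of AudioMask: only IDs in `allowed` keep their scores."""
        def __init__(self, allowed: torch.Tensor):
            self.allowed = allowed  # 1-D tensor of permitted token IDs

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            mask = torch.full_like(scores, float("-inf"))
            mask[:, self.allowed] = 0.0  # allowed IDs keep their original scores
            return scores + mask         # everything else becomes -inf

    # Usage: model.generate(..., logits_processor=[AllowListMask(allowed_ids)])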
@@ -75,19 +81,26 @@ class EosStoppingCriteria(StoppingCriteria):
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         # Check if the *last* generated token is the EOS token
+        # Check input_ids shape to prevent index error on first token
         if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
-            # print("StoppingCriteria: EOS detected.")
+            print("StoppingCriteria: EOS detected.")
             return True
         return False
 
 # 4) Custom AudioStreamer ----------------------------------------------
 class AudioStreamer(BaseStreamer):
-    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop):
+    """
+    Custom streamer that processes audio tokens, decodes them with SNAC,
+    and sends audio bytes over a WebSocket.
+    """
+    # Added target_device parameter
+    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
         self.ws = ws
         self.snac = snac_decoder
         self.masker = audio_mask  # Reference to the mask to update sent_blocks
         self.loop = loop  # Event loop of the main thread for run_coroutine_threadsafe
-        self.device = snac_decoder.device  # Get device from the decoder
+        # Use the passed target_device
+        self.device = target_device
         self.buf: list[int] = []  # Buffer for audio token values (AUDIO_BASE subtracted)
         self.tasks = set()  # Keep track of pending send tasks
 
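EosStoppingCriteria ends generation as soon as the newest token is EOS_TOKEN, and generate() expects such criteria wrapped in a StoppingCriteriaList (as the startup code further down also notes). A hedged usage sketch, with illustrative class and variable names:

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnToken(StoppingCriteria):
        """Stop once the most recently generated token equals `token_id`."""
        def __init__(self, token_id: int):
            self.token_id = token_id

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            return input_ids.shape[1] > 0 and input_ids[0, -1].item() == self.token_id

    criteria = StoppingCriteriaList([StopOnToken(128258)])  # 128258 = EOS_TOKEN above
    # model.generate(..., stopping_criteria=criteria)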
@@ -95,34 +108,46 @@ class AudioStreamer(BaseStreamer):
         """
         Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
         NOTE: The mapping from the 7 tokens to the 3 SNAC codebooks (l1, l2, l3)
-        is based on a common interleaving hypothesis. Verify if model docs specify otherwise.
+        is CRITICAL and based on the structure used by the specific model.
+        This implementation uses the mapping derived from the user's previous code.
+        If audio is distorted, try the alternative mapping commented out below.
         """
         if len(block7) != 7:
             print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
             return b""  # Return empty bytes if block is incomplete
 
-        # --- Hypothesis: Interleaving mapping ---
-        # Assumes 7 tokens map to 3 codebooks like this:
-        # Codebook 1 (l1) uses tokens at indices 0, 3, 6
-        # Codebook 2 (l2) uses tokens at indices 1, 4
-        # Codebook 3 (l3) uses tokens at indices 2, 5
+        # --- Mapping based on user's previous version ---
         try:
-            l1 = [block7[0], block7[3], block7[6]]
-            l2 = [block7[1], block7[4]]
-            l3 = [block7[2], block7[5]]
-        except IndexError:
-            print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
+            l1 = [block7[0]]                                   # Index 0
+            l2 = [block7[1], block7[4]]                        # Indices 1, 4
+            l3 = [block7[2], block7[3], block7[5], block7[6]]  # Indices 2, 3, 5, 6
+            # --- Alternative hypothesis mapping (try if the above fails) ---
+            # l1 = [block7[0], block7[3], block7[6]]  # Indices 0, 3, 6
+            # l2 = [block7[1], block7[4]]             # Indices 1, 4
+            # l3 = [block7[2], block7[5]]             # Indices 2, 5
+        except IndexError as e:
+            print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}, Error: {e}")
             return b""
 
         # Convert lists to tensors on the correct device
-        codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
-        codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
-        codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
-        codes = [codes_l1, codes_l2, codes_l3]  # List of tensors for SNAC
+        try:
+            codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
+            codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
+            codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
+            codes = [codes_l1, codes_l2, codes_l3]  # List of tensors for SNAC
+        except Exception as e:
+            print(f"Streamer Error: Failed converting lists to tensors. Error: {e}")
+            return b""
 
         # Decode using SNAC
-        with torch.no_grad():
-            audio = self.snac.decode(codes)[0]  # Decode expects list of tensors, result might have batch dim
+        try:
+            with torch.no_grad():
+                # Ensure snac_decoder is on the correct device already (done via .to(device))
+                audio = self.snac.decode(codes)[0]  # Decode expects list of tensors, result might have batch dim
+        except Exception as e:
+            print(f"Streamer Error: snac.decode failed. Input shapes: {[c.shape for c in codes]}. Error: {e}")
+            return b""
+
 
         # Squeeze, move to CPU, convert to numpy
         audio_np = audio.squeeze().detach().cpu().numpy()
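Each 7-token frame is split across SNAC's three hierarchical codebooks before decoding; as the new docstring stresses, the exact split is model-specific. A sketch of the decode call and the tensor shapes it expects, assuming the snac package's list-of-[1, T] code layout and the 1/2/4 split chosen above (the code values are dummies):

    import torch
    from snac import SNAC

    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()

    # One frame: 1 coarse, 2 mid, 4 fine codes (mirrors _decode_block's mapping)
    l1 = torch.tensor([[12]])              # shape [1, 1]
    l2 = torch.tensor([[34, 56]])          # shape [1, 2]
    l3 = torch.tensor([[78, 90, 11, 22]])  # shape [1, 4]

    with torch.no_grad():
        audio = snac_model.decode([l1, l2, l3])  # 24 kHz waveform tensor
    print(audio.shape)  # expected [1, 1, n_samples]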
@@ -137,7 +162,7 @@ class AudioStreamer(BaseStreamer):
             return
         try:
             await self.ws.send_bytes(data)
-            # print(f"Streamer: Sent {len(data)} audio bytes.")
+            # print(f"Streamer: Sent {len(data)} audio bytes.")  # Optional: debug log
         except WebSocketDisconnect:
             print("Streamer: WebSocket disconnected during send.")
         except Exception as e:
@@ -151,9 +176,12 @@ class AudioStreamer(BaseStreamer):
         # Ensure value is on CPU and flatten to a list of ints
         if value.numel() == 0:
             return
-        new_token_ids = value.squeeze().tolist()
-        if isinstance(new_token_ids, int):  # Handle single token case
-            new_token_ids = [new_token_ids]
+        # Handle potential shape issues, ensure it's iterable
+        try:
+            new_token_ids = value.view(-1).tolist()
+        except Exception as e:
+            print(f"Streamer Error: Could not process incoming tensor: {value}, Error: {e}")
+            return
 
         for t in new_token_ids:
             if t == EOS_TOKEN:
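put() runs on the worker thread that executes generate() (the now-removed handler below launched it via asyncio.to_thread), so it cannot await; finished audio chunks are instead scheduled onto the main event loop with asyncio.run_coroutine_threadsafe. A minimal self-contained sketch of that hand-off pattern:

    import asyncio
    import threading

    async def send(data: bytes):
        # Stand-in for ws.send_bytes(); runs on the event-loop thread.
        print(f"sending {len(data)} bytes")

    def worker(loop: asyncio.AbstractEventLoop):
        # Plain thread (e.g. the generate() thread): schedule work on the loop.
        future = asyncio.run_coroutine_threadsafe(send(b"\x00" * 14), loop)
        future.result(timeout=5)  # optionally block this thread until sent

    async def main():
        loop = asyncio.get_running_loop()
        t = threading.Thread(target=worker, args=(loop,))
        t.start()
        await asyncio.to_thread(t.join)  # join without blocking the event loop

    asyncio.run(main())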
@@ -161,208 +189,4 @@ class AudioStreamer(BaseStreamer):
                 # EOS is handled by StoppingCriteria, no action needed here except maybe logging.
                 break  # Stop processing this batch if EOS is found
 
-            if t == NEW_BLOCK:
-                # print("Streamer: NEW_BLOCK token encountered.")
-                # NEW_BLOCK indicates the start of audio, might reset buffer if needed
-                self.buf.clear()
-                continue  # Move to the next token
-
-            # Check if token is within the expected audio range
-            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
-                self.buf.append(t - AUDIO_BASE)  # Store value relative to base
-            else:
-                # Log unexpected tokens (like START_TOKEN or others if generation goes wrong)
-                # print(f"Streamer Warning: Ignoring unexpected token {t}")
-                pass  # Ignore tokens outside the audio range
-
-            # If buffer has 7 tokens, decode and send
-            if len(self.buf) == 7:
-                audio_bytes = self._decode_block(self.buf)
-                self.buf.clear()  # Clear buffer after processing
-
-                if audio_bytes:  # Only send if decoding was successful
-                    # Schedule the async send function to run on the main event loop
-                    future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
-                    self.tasks.add(future)
-                    # Optional: Remove completed tasks to prevent memory leak if generation is very long
-                    future.add_done_callback(self.tasks.discard)
-
-
-                # Allow EOS only after the first full block has been processed and scheduled for sending
-                if self.masker.sent_blocks == 0:
-                    # print("Streamer: First audio block processed, allowing EOS.")
-                    self.masker.sent_blocks = 1  # Update state in the mask
-
-        # Note: No need to explicitly wait for tasks here. put() should return quickly.
-
-    def end(self):
-        """Called by generate() when generation finishes."""
-        # Handle any remaining tokens in the buffer (optional, here we discard them)
-        if len(self.buf) > 0:
-            print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
-            self.buf.clear()
-
-        # Optional: Wait briefly for any outstanding send tasks to complete?
-        # This is tricky because end() is sync. A robust solution might involve
-        # signaling the WebSocket handler to wait before closing.
-        # For simplicity, we rely on FastAPI/Uvicorn's graceful shutdown handling.
-        # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")
-        pass
-
-# 5) FastAPI app -------------------------------------------------------
-app = FastAPI()
-
-@app.on_event("startup")
-async def load_models_startup():  # Make startup async if needed for future async loads
-    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
-
-    print(f"🚀 Starting up on device: {device}")
-    print("⏳ Loading models …", flush=True)
-
-    tok = AutoTokenizer.from_pretrained(REPO)
-    print("Tokenizer loaded.")
-
-    # Load SNAC first (usually smaller)
-    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
-    print(f"SNAC loaded to {snac.device}.")
-
-    # Load the main model
-    model = AutoModelForCausalLM.from_pretrained(
-        REPO,
-        device_map={"": 0} if device == "cuda" else None,  # Assign to GPU 0 if cuda
-        torch_dtype=torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32,  # Use bfloat16 if supported
-        low_cpu_mem_usage=True,  # Good practice for large models
-    )
-    model.config.pad_token_id = model.config.eos_token_id  # Set pad token
-    print(f"Model loaded to {model.device}.")
-
-    # Ensure model is in evaluation mode
-    model.eval()
-
-    # Initialize AudioMask (needs AUDIO_IDS on the correct device)
-    audio_ids_device = AUDIO_IDS_CPU.to(device)
-    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
-    print("AudioMask initialized.")
-
-    # Initialize StoppingCriteria
-    # IMPORTANT: Create the list and add the criteria instance
-    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
-    print("StoppingCriteria initialized.")
-
-    print("✅ Models loaded and ready!", flush=True)
-
-@app.get("/")
-def hello():
-    return {"status": "ok", "message": "TTS Service is running"}
-
-# 6) Helper to build the prompt ----------------------------------------
-def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
-    """Builds the input_ids and attention_mask for the model."""
-    # Format: <START> <VOICE>: <TEXT> <NEW_BLOCK>
-    prompt_text = f"{voice}: {text}"
-    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
-
-    # Construct input_ids tensor
-    input_ids = torch.cat([
-        torch.tensor([[START_TOKEN]], device=device),  # Start token
-        prompt_ids,                                    # Encoded prompt
-        torch.tensor([[NEW_BLOCK]], device=device)     # New block token to trigger audio
-    ], dim=1)
-
-    # Create attention mask (all ones)
-    attention_mask = torch.ones_like(input_ids)
-    return input_ids, attention_mask
-
-# 7) WebSocket endpoint (simplified with streamer) ---------------------
-@app.websocket("/ws/tts")
-async def tts(ws: WebSocket):
-    await ws.accept()
-    print("Client connected")
-    streamer = None  # Initialize for finally block
-    main_loop = asyncio.get_running_loop()  # Get the current event loop
-
-    try:
-        # Receive configuration
-        req_text = await ws.receive_text()
-        print(f"Received request: {req_text}")
-        req = json.loads(req_text)
-        text = req.get("text", "Hallo Welt, wie geht es dir heute?")  # Default text
-        voice = req.get("voice", "Jakob")  # Default voice
-
-        if not text:
-            await ws.close(code=1003, reason="Text cannot be empty")
-            return
-
-        print(f"Generating audio for: '{text}' with voice '{voice}'")
-
-        # Prepare prompt
-        ids, attn = build_prompt(text, voice)
-
-        # --- Reset stateful components ---
-        masker.reset()  # CRITICAL: Reset the mask state for the new request
-
-        # --- Create streamer instance ---
-        streamer = AudioStreamer(ws, snac, masker, main_loop)
-
-        # --- Run model.generate in a separate thread ---
-        # This prevents blocking the main FastAPI event loop
-        print("Starting generation...")
-        await asyncio.to_thread(
-            model.generate,
-            input_ids=ids,
-            attention_mask=attn,
-            max_new_tokens=1500,  # Limit generation length (adjust as needed)
-            logits_processor=[masker],
-            stopping_criteria=stopping_criteria,
-            do_sample=False,  # Use greedy decoding for potentially more stable audio
-            # do_sample=True, temperature=0.7, top_p=0.95,  # Or use sampling
-            use_cache=True,
-            streamer=streamer  # Pass the custom streamer
-            # No need to manage past_key_values manually
-        )
-        print("Generation finished.")
-
-    except WebSocketDisconnect:
-        print("Client disconnected.")
-    except json.JSONDecodeError:
-        print("❌ Invalid JSON received.")
-        await ws.close(code=1003, reason="Invalid JSON format")  # 1003 = Cannot accept data type
-    except Exception as e:
-        error_details = traceback.format_exc()
-        print(f"❌ WS error: {e}\n{error_details}", flush=True)
-        # Try to send an error message before closing, if possible
-        error_payload = json.dumps({"error": str(e)})
-        try:
-            if ws.client_state.name == "CONNECTED":
-                await ws.send_text(error_payload)  # Send error as text/JSON
-        except Exception:
-            pass  # Ignore errors during error reporting
-        # Close with internal server error code
-        if ws.client_state.name == "CONNECTED":
-            await ws.close(code=1011)  # 1011 = Internal Server Error
-    finally:
-        # Ensure the streamer's end method is called if it exists
-        if streamer:
-            try:
-                streamer.end()
-            except Exception as e_end:
-                print(f"Error during streamer.end(): {e_end}")
-
-        # Ensure the WebSocket is closed
-        print("Closing connection.")
-        if ws.client_state.name != "DISCONNECTED":
-            try:
-                await ws.close(code=1000)  # 1000 = Normal Closure
-            except RuntimeError as e_close:
-                # Can happen if connection is already closing/closed
-                print(f"Runtime error closing websocket: {e_close}")
-            except Exception as e_close_final:
-                print(f"Error closing websocket: {e_close_final}")
-        print("Connection closed.")
-
-# 8) Dev start ---------------------------------------------------------
-if __name__ == "__main__":
-    import uvicorn
-    print("Starting Uvicorn server...")
-    # Use reload=True only for development; remove for production
-    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")  # , reload=True
+            if t == NEW_BLOCK:
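For reference, the removed /ws/tts handler spoke a simple protocol: the client sends one JSON text frame with "text" and "voice", then receives binary audio frames until the server closes with code 1000 (errors arrive as a JSON text frame). A hypothetical client sketch, assuming the third-party websockets package; the exact byte format of the audio frames is set by conversion code outside the visible hunks:

    import asyncio
    import json
    import websockets  # pip install websockets

    async def synthesize(url: str = "ws://localhost:7860/ws/tts") -> bytes:
        async with websockets.connect(url) as ws:
            await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
            chunks = []
            while True:
                try:
                    msg = await ws.recv()
                except websockets.exceptions.ConnectionClosed:
                    break  # server closed the stream (1000 = normal end)
                if isinstance(msg, bytes):
                    chunks.append(msg)  # binary frame = audio chunk
                else:
                    print("server reported:", msg)  # text frame = JSON error
                    break
            return b"".join(chunks)

    audio = asyncio.run(synthesize())
    print(f"received {len(audio)} bytes of audio")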
 