# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback  # Import traceback for better error logging

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC  # Ensure you have 'pip install snac'

# --- Globals (populated in load_models_startup) ---
tok = None
model = None
snac = None
masker = None
stopping_criteria = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# 0) Login + Device ---------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    print("🔑 Logging in to Hugging Face Hub...")
    login(HF_TOKEN)

# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment if needed to work around a PyTorch 2.2 bug

# 1) Constants --------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# CHUNK_TOKENS = 50  # Not directly used with the streamer approach
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7  # 28672 codes
# AUDIO_IDS is moved to the correct device later in load_models_startup
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)

# 2) Logit mask -------------------------------------------------------
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        # Allow NEW_BLOCK and all valid audio tokens initially
        self.allow = torch.cat([
            torch.tensor([new_block_token_id], device=audio_ids.device),  # NEW_BLOCK token ID
            audio_ids
        ], dim=0)
        self.eos = torch.tensor([eos_token_id], device=audio_ids.device)  # EOS token ID as tensor
        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)    # Precomputed combined tensor
        self.sent_blocks = 0  # State: number of audio blocks sent

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Determine which tokens are allowed based on whether blocks have been sent
        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
        # Create a mask initialized to negative infinity
        mask = torch.full_like(scores, float("-inf"))
        # Set allowed token scores to 0 (effectively allowing them)
        mask[:, current_allow] = 0
        # Apply the mask to the scores
        return scores + mask

    def reset(self):
        """Resets the state for a new generation request."""
        self.sent_blocks = 0

# 3) StoppingCriteria for EOS ------------------------------------------
# generate() needs explicit stopping criteria when using a streamer
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Check if the *last* generated token is the EOS token
        if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
            # print("StoppingCriteria: EOS detected.")
            return True
        return False

# 4) Custom AudioStreamer ----------------------------------------------
class AudioStreamer(BaseStreamer):
    # --- __init__ accepts the target device explicitly ---
    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask,
                 loop: asyncio.AbstractEventLoop, target_device: str):
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask  # Reference to the mask to update sent_blocks
        self.loop = loop          # Event loop of the main thread for run_coroutine_threadsafe
        # --- Use the passed target_device ---
        self.device = target_device
        self.buf: list[int] = []  # Buffer for audio token values (AUDIO_BASE subtracted)
        self.tasks = set()        # Keep track of pending send tasks

    def _decode_block(self, block7: list[int]) -> bytes:
        """
        Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
        NOTE: The mapping from the 7 tokens to the 3 SNAC codebooks (l1, l2, l3)
        is based on the structure found in the previous while-loop version.
        If audio is distorted, this mapping is the primary suspect.
        Ensure this mapping is correct for the specific model!
        """
        if len(block7) != 7:
            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b""  # Return empty bytes if the block is incomplete

        # --- Mapping derived from the previous version (indices [0], [1,4], [2,3,5,6]) ---
        # This seems more likely to be correct for Kartoffel_Orpheus if the previous version worked.
        try:
            l1 = [block7[0]]                                   # Index 0
            l2 = [block7[1], block7[4]]                        # Indices 1, 4
            l3 = [block7[2], block7[3], block7[5], block7[6]]  # Indices 2, 3, 5, 6
        except IndexError:
            print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
            return b""

        # --- Alternative hypothesis (commented out): interleaving mapping ---
        # try:
        #     l1 = [block7[0], block7[3], block7[6]]  # Codebook 1 indices: 0, 3, 6
        #     l2 = [block7[1], block7[4]]             # Codebook 2 indices: 1, 4
        #     l3 = [block7[2], block7[5]]             # Codebook 3 indices: 2, 5
        # except IndexError:
        #     print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
        #     return b""
        # --- End alternative hypothesis ---

        # Convert lists to tensors on the correct device
        # (self.device was set in __init__)
        codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
        codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
        codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
        codes = [codes_l1, codes_l2, codes_l3]  # List of tensors for SNAC

        # Decode using SNAC
        with torch.no_grad():
            # self.snac is already on self.device from load_models_startup
            audio = self.snac.decode(codes)[0]  # decode() expects a list of tensors; result may have a batch dim

        # Squeeze, move to CPU, convert to numpy
        audio_np = audio.squeeze().detach().cpu().numpy()

        # Convert to 16-bit PCM bytes
        audio_bytes = (audio_np * 32767).astype("int16").tobytes()
        return audio_bytes

    async def _send_audio_bytes(self, data: bytes):
        """Coroutine to send bytes over the WebSocket."""
        if not data:  # Don't send empty bytes
            return
        try:
            await self.ws.send_bytes(data)
            # print(f"Streamer: Sent {len(data)} audio bytes.")
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as e:
            print(f"Streamer: Error sending bytes: {e}")

    def put(self, value: torch.LongTensor):
        """
        Receives new token IDs (tensor) from generate() (runs in a worker thread).
        Processes tokens, decodes full blocks, and schedules sending
        via run_coroutine_threadsafe.
        """
        # Flatten the incoming tensor to a list of ints
        if value.numel() == 0:
            return
        new_token_ids = value.squeeze().tolist()
        if isinstance(new_token_ids, int):  # Handle single-token case
            new_token_ids = [new_token_ids]

        for t in new_token_ids:
            if t == EOS_TOKEN:
                # print("Streamer: EOS token encountered.")
                # EOS is handled by the StoppingCriteria; nothing to do here except maybe logging.
                break  # Stop processing this batch if EOS is found

            if t == NEW_BLOCK:
                # print("Streamer: NEW_BLOCK token encountered.")
                # NEW_BLOCK indicates the start of audio; reset the buffer
                self.buf.clear()
                continue  # Move to the next token

            # Check if the token is within the expected audio range
            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                # Store the value relative to the base (IMPORTANT for _decode_block)
                self.buf.append(t - AUDIO_BASE)
            else:
                # Unexpected tokens (like START_TOKEN, or others if generation goes wrong)
                # print(f"Streamer Warning: Ignoring unexpected token {t}")
                pass  # Ignore tokens outside the audio range

            # If the buffer holds 7 tokens, decode and send
            if len(self.buf) == 7:
                audio_bytes = self._decode_block(self.buf)
                self.buf.clear()  # Clear buffer after processing

                if audio_bytes:  # Only send if decoding was successful
                    # Schedule the async send function on the main event loop
                    future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                    self.tasks.add(future)
                    # Remove completed tasks to prevent a memory leak if generation is very long
                    future.add_done_callback(self.tasks.discard)

                    # Allow EOS only after the first full block has been processed and scheduled for sending
                    if self.masker.sent_blocks == 0:
                        # print("Streamer: First audio block processed, allowing EOS.")
                        self.masker.sent_blocks = 1  # Update state in the mask

        # Note: No need to explicitly wait for tasks here; put() should return quickly.

    def end(self):
        """Called by generate() when generation finishes."""
        # Handle any remaining tokens in the buffer (optional; here we discard them)
        if len(self.buf) > 0:
            print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
            self.buf.clear()

        # Optional: wait briefly for any outstanding send tasks to complete?
        # This is tricky because end() is sync. A robust solution might involve
        # signaling the WebSocket handler to wait before closing.
        # For simplicity, we rely on FastAPI/Uvicorn's graceful shutdown handling.
        # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")

# 5) FastAPI app ------------------------------------------------------
app = FastAPI()

@app.on_event("startup")
async def load_models_startup():  # Startup is async in case future loads need await
    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU

    print(f"🚀 Starting up on device: {device}")
    print("⏳ Loading models …", flush=True)

    tok = AutoTokenizer.from_pretrained(REPO)
    print("Tokenizer loaded.")

    # Load SNAC first (usually smaller)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    print(f"SNAC loaded to {device}.")  # Use the global device variable

    # Load the main model
    # Determine the appropriate dtype based on device and support
    model_dtype = torch.float32  # Default to float32 for CPU
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            model_dtype = torch.bfloat16
            print("Using bfloat16 for model.")
        else:
            model_dtype = torch.float16  # Fall back to float16 if bfloat16 is not supported
            print("Using float16 for model.")

    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,  # Assign to GPU 0 if cuda
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True,  # Good practice for large models
    )
    model.config.pad_token_id = model.config.eos_token_id  # Set pad token
    print(f"Model loaded to {model.device} with dtype {model.dtype}.")

    # Ensure the model is in evaluation mode
    model.eval()

    # Initialize AudioMask (needs AUDIO_IDS on the correct device)
    audio_ids_device = AUDIO_IDS_CPU.to(device)
    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
    print("AudioMask initialized.")

    # Initialize StoppingCriteria
    # IMPORTANT: Create the list and add the criteria instance
    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
    print("StoppingCriteria initialized.")

    print("✅ Models loaded and ready!", flush=True)

@app.get("/")
def hello():
    return {"status": "ok", "message": "TTS Service is running"}

# 6) Prompt-building helper --------------------------------------------
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Builds the input_ids and attention_mask for the model."""
    # Format: "<voice>: <text>"
    prompt_text = f"{voice}: {text}"
    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)

    # Construct the input_ids tensor
    input_ids = torch.cat([
        torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),  # Start token
        prompt_ids,                                                      # Encoded prompt
        torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)     # New-block token to trigger audio
    ], dim=1)

    # Create the attention mask (all ones)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask

# 7) WebSocket endpoint (simplified via the streamer) -------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    print("🔌 Client connected")

    streamer = None  # Initialize for the finally block
    main_loop = asyncio.get_running_loop()  # Get the current event loop

    try:
        # Receive configuration
        req_text = await ws.receive_text()
        print(f"Received request: {req_text}")
        req = json.loads(req_text)
        text = req.get("text", "Hallo Welt, wie geht es dir heute?")  # Default text
        voice = req.get("voice", "Jakob")  # Default voice

        if not text:
            print("⚠️ Request text is empty.")
            await ws.close(code=1003, reason="Text cannot be empty")  # 1003 = Cannot accept data type
            return

        print(f"Generating audio for: '{text}' with voice '{voice}'")

        # Prepare the prompt
        ids, attn = build_prompt(text, voice)

        # --- Reset stateful components ---
        masker.reset()  # CRITICAL: reset the mask state for the new request

        # --- Create the streamer instance ---
        # --- Pass the global 'device' variable ---
        streamer = AudioStreamer(ws, snac, masker, main_loop, device)

        # --- Run model.generate in a separate thread ---
        # This prevents blocking the main FastAPI event loop
        print("Starting generation in background thread...")
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
            max_new_tokens=1500,  # Limit generation length (adjust as needed)
            logits_processor=[masker],
            stopping_criteria=stopping_criteria,
            do_sample=False,  # Greedy decoding for potentially more stable audio
            # do_sample=True, temperature=0.7, top_p=0.95,  # Or use sampling
            use_cache=True,
            streamer=streamer  # Pass the custom streamer
            # No need to manage past_key_values manually
        )
        print("Generation thread finished.")

    except WebSocketDisconnect:
        print("🔌 Client disconnected.")
    except json.JSONDecodeError:
        print("❌ Invalid JSON received.")
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1003, reason="Invalid JSON format")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"❌ WS error: {e}\n{error_details}", flush=True)
        # Try to send an error message before closing, if possible
        error_payload = json.dumps({"error": str(e)})
        try:
            if ws.client_state.name == "CONNECTED":
                await ws.send_text(error_payload)  # Send the error as text/JSON
        except Exception:
            pass  # Ignore errors during error reporting
        # Close with an internal server error code
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)  # 1011 = Internal Server Error
    finally:
        # Ensure the streamer's end method is called if it exists
        if streamer:
            try:
                # print("Calling streamer.end()")
                streamer.end()
            except Exception as e_end:
                print(f"Error during streamer.end(): {e_end}")

        # Ensure the WebSocket is closed
        print("Closing connection.")
        if ws.client_state.name == "CONNECTED":
            try:
                await ws.close(code=1000)  # 1000 = Normal Closure
            except RuntimeError as e_close:
                # Can happen if the connection is already closing/closed
                print(f"Runtime error closing websocket: {e_close}")
            except Exception as e_close_final:
                print(f"Error closing websocket: {e_close_final}")
        elif ws.client_state.name != "DISCONNECTED":
            print(f"WebSocket final state: {ws.client_state.name}")
        print("Connection closed.")

# 8) Dev start ---------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    print("Starting Uvicorn server...")
    # Use reload=True only for development; remove for production.
    # Consider --workers 1 if you see issues with multiple workers and global state / GPU memory.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")  # , reload=True
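
# 9) Example client (sketch) -------------------------------------------
# The commented snippet below is NOT part of the server; it is a minimal
# client sketch for manual testing. Assumptions: the server above runs
# locally on port 7860, the `websockets` package is installed, and the
# audio is raw 16-bit mono PCM at 24 kHz (the SNAC model loaded above is
# the 24 kHz variant). Copy it into its own file (e.g. client_example.py)
# and run it separately.
#
#   import asyncio, json, wave
#   import websockets  # pip install websockets
#
#   async def main():
#       uri = "ws://localhost:7860/ws/tts"
#       async with websockets.connect(uri) as ws:
#           # Same JSON shape the /ws/tts endpoint above expects
#           await ws.send(json.dumps({"text": "Hallo Welt, wie geht es dir heute?", "voice": "Jakob"}))
#           pcm = bytearray()
#           try:
#               while True:
#                   msg = await ws.recv()
#                   if isinstance(msg, bytes):
#                       pcm.extend(msg)                 # audio chunk
#                   else:
#                       print("Server message:", msg)   # e.g. JSON error payload
#           except websockets.ConnectionClosed:
#               pass  # server closes with code 1000 once generation is done
#           with wave.open("out.wav", "wb") as wf:
#               wf.setnchannels(1)
#               wf.setsampwidth(2)       # 16-bit PCM
#               wf.setframerate(24000)   # assumption: SNAC 24 kHz
#               wf.writeframes(bytes(pcm))
#           print(f"Wrote {len(pcm)} bytes of PCM to out.wav")
#
#   asyncio.run(main())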