# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback  # Imported for better error logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessor,
    StoppingCriteria,
    StoppingCriteriaList,
)
# Import BaseStreamer for the streamer interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC  # Ensure you have 'pip install snac'

# --- Globals (populated in load_models) ---
tok = None
model = None
snac = None
masker = None
stopping_criteria = None
# actual_eos_token_id = None  # Reverted to the constant below
device = "cuda" if torch.cuda.is_available() else "cpu"

# 0) Login + device ----------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    print("🔑 Logging in to Hugging Face Hub...")
    login(HF_TOKEN)

# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment if needed to work around a PyTorch 2.2 bug

# 1) Constants ---------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
START_TOKEN = 128259
NEW_BLOCK = 128257
# --- Hardcoded EOS token ID used throughout (compared against model/tokenizer IDs at startup) ---
EOS_TOKEN = 128258
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7  # 28672 codes
CODEBOOK_SIZE = 4096  # Explicitly define the codebook size
# AUDIO_IDS is created on CPU here and moved to the correct device later in load_models
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)

# 2) Logit mask --------------------------------------------------------
# Uses the constant EOS_TOKEN
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
        eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
        self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
        self.eos = eos_tensor
        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
        self.sent_blocks = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Only allow EOS once at least one audio block has been streamed
        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
        mask = torch.full_like(scores, float("-inf"))
        mask[:, current_allow] = 0
        return scores + mask

    def reset(self):
        self.sent_blocks = 0

# 3) StoppingCriteria for EOS ------------------------------------------
# Uses the constant EOS_TOKEN
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.eos_token_id is None:
            return False
        # .any() keeps this robust instead of relying on implicit single-element tensor truthiness
        if input_ids.shape[1] > 0 and bool((input_ids[:, -1] == self.eos_token_id).any()):
            print(f"StoppingCriteria: EOS detected (ID: {self.eos_token_id}).")
            return True
        return False

# 4) Custom AudioStreamer ----------------------------------------------
class AudioStreamer(BaseStreamer):
    # The constant EOS_TOKEN is passed in here as well
    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask,
                 loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask
        self.loop = loop
        self.device = target_device
        self.eos_token_id = eos_token_id  # Store the constant EOS ID
        self.buf: list[int] = []
        self.tasks = set()
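    # One audio frame consists of 7 consecutive token values. Based on the
    # mapping used in _decode_block below (assumed to match the Kartoffel_Orpheus
    # interleaving; not independently verified), the slots feed SNAC's three
    # codebooks as follows:
    #   slot:      0   1   2   3   4   5   6
    #   codebook:  l1  l2  l3  l3  l2  l3  l3
    # i.e. l1 = [s0], l2 = [s1, s4], l3 = [s2, s3, s5, s6].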
    def _decode_block(self, block7: list[int]) -> bytes:
        """
        Decodes a block of 7 audio token values (AUDIO_BASE already subtracted)
        into audio bytes.
        NOTE: Extracts the base code value (0-4095) using modulo, assuming each
        input value represents (slot_offset + code_value). Maps the extracted
        values using the structure potentially correct for Kartoffel_Orpheus.
        """
        if len(block7) != 7:
            # print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b""  # Less verbose logging

        try:
            # --- Extract the base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
            code_val_0 = block7[0] % CODEBOOK_SIZE
            code_val_1 = block7[1] % CODEBOOK_SIZE
            code_val_2 = block7[2] % CODEBOOK_SIZE
            code_val_3 = block7[3] % CODEBOOK_SIZE
            code_val_4 = block7[4] % CODEBOOK_SIZE
            code_val_5 = block7[5] % CODEBOOK_SIZE
            code_val_6 = block7[6] % CODEBOOK_SIZE

            # --- Map the extracted code values to the SNAC codebooks (l1, l2, l3) ---
            l1 = [code_val_0]
            l2 = [code_val_1, code_val_4]
            l3 = [code_val_2, code_val_3, code_val_5, code_val_6]
        except IndexError:
            print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
            return b""
        except Exception as e_map:
            print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
            return b""

        # --- Convert lists to tensors on the correct device ---
        try:
            codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
            codes = [codes_l1, codes_l2, codes_l3]
        except Exception as e_tensor:
            print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
            return b""

        # --- Decode using SNAC ---
        try:
            with torch.no_grad():
                audio = self.snac.decode(codes)[0]
        except Exception as e_decode:
            print(f"Streamer Error: Exception during snac.decode: {e_decode}")
            # Add more details if needed, e.g. shapes: {[c.shape for c in codes]}
            return b""

        # --- Post-processing ---
        try:
            audio_np = audio.squeeze().detach().cpu().numpy()
            audio_bytes = (audio_np * 32767).astype("int16").tobytes()
            return audio_bytes
        except Exception as e_post:
            print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
            return b""

    async def _send_audio_bytes(self, data: bytes):
        """Coroutine to send bytes over the WebSocket."""
        if not data:
            return
        try:
            await self.ws.send_bytes(data)
        except WebSocketDisconnect:
            # Expected if the client disconnects first; don't log an error
            pass
        except Exception as e:
            if "Cannot call \"send\" once a close message has been sent" in str(e) or \
               "Connection is closed" in str(e):
                # Expected if the client disconnects during generation; suppress repetitive logs
                pass
            else:
                print(f"Streamer: Error sending bytes: {e}")
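    # Token protocol expected from generate(), as handled in put() below:
    # NEW_BLOCK (128257) clears the current buffer, tokens in
    # [AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN) are audio codes, and every 7
    # buffered codes are decoded and streamed to the client. EOS is handled
    # by the StoppingCriteria, not here.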
    def put(self, value: torch.LongTensor):
        """
        Receives new token IDs (tensor) from generate().
        Processes tokens, decodes full blocks, and schedules sending.
        """
        if value.numel() == 0:
            return
        new_token_ids = value.squeeze().cpu().tolist()
        if isinstance(new_token_ids, int):
            new_token_ids = [new_token_ids]

        for t in new_token_ids:
            # No need to check for EOS here; the StoppingCriteria handles it
            if t == NEW_BLOCK:
                self.buf.clear()
                continue
            # The constant EOS_TOKEN could be compared here if needed (e.g. for logging)
            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                self.buf.append(t - AUDIO_BASE)  # Store the value relative to AUDIO_BASE
            # else:  # Optionally log ignored tokens
            #     if t != self.eos_token_id:  # Don't warn about the EOS token itself
            #         print(f"Streamer Warning: Ignoring unexpected token {t}")

            if len(self.buf) == 7:
                audio_bytes = self._decode_block(self.buf)
                self.buf.clear()
                if audio_bytes:
                    future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                    self.tasks.add(future)
                    future.add_done_callback(self.tasks.discard)
                    # Allow EOS from now on, since at least one block has been sent
                    if self.masker.sent_blocks == 0:
                        self.masker.sent_blocks = 1

    def end(self):
        """Called by generate() when generation finishes."""
        if len(self.buf) > 0:
            print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
            self.buf.clear()

# 5) FastAPI app -------------------------------------------------------
app = FastAPI()

@app.on_event("startup")
async def load_models_startup():
    # Keep global references; EOS_TOKEN itself is a module-level constant
    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU

    print(f"🚀 Starting up on device: {device}")
    print("⏳ Loading models …", flush=True)

    tok = AutoTokenizer.from_pretrained(REPO)
    print("Tokenizer loaded.")

    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    print(f"SNAC loaded to {device}.")

    model_dtype = torch.float32
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            model_dtype = torch.bfloat16
            print("Using bfloat16 for model.")
        else:
            model_dtype = torch.float16
            print("Using float16 for model.")

    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True,
    )
    print(f"Model loaded to {model.device} with dtype {model.dtype}.")
    model.eval()

    # --- Log the model/tokenizer EOS token IDs for comparison, but use the constant ---
    conf_eos = model.config.eos_token_id
    tok_eos = tok.eos_token_id
    print(f"Model Config EOS ID: {conf_eos}")
    print(f"Tokenizer EOS ID: {tok_eos}")
    print(f"Using Constant EOS_TOKEN: {EOS_TOKEN}")
    if conf_eos != EOS_TOKEN or tok_eos != EOS_TOKEN:
        print(f"⚠️ WARNING: Constant EOS_TOKEN {EOS_TOKEN} differs from model/tokenizer IDs ({conf_eos}/{tok_eos}).")
    # --- End EOS comparison ---

    # Set pad_token_id if it is None (use the constant EOS)
    if model.config.pad_token_id is None:
        print(f"Setting model.config.pad_token_id to constant EOS token ID ({EOS_TOKEN})")
        model.config.pad_token_id = EOS_TOKEN

    audio_ids_device = AUDIO_IDS_CPU.to(device)
    # Pass the constant EOS_TOKEN to the mask
    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
    print("AudioMask initialized.")

    # Pass the constant EOS_TOKEN to the stopping criteria
    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
    print("StoppingCriteria initialized.")

    print("✅ Models loaded and ready!", flush=True)

@app.get("/")
def hello():
    return {"status": "ok", "message": "TTS Service is running"}

# 6) Prompt-building helper --------------------------------------------
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Builds the input_ids and attention_mask for the model."""
    prompt_text = f"{voice}: {text}"
    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([
        torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),
        prompt_ids,
        torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)
    ], dim=1)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
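# Illustrative example: build_prompt("Hallo Welt", "Jakob") conceptually yields
#   input_ids = [[START_TOKEN, *<token IDs of "Jakob: Hallo Welt">, NEW_BLOCK]]
#             = [[128259, ..., 128257]]
# with an attention_mask of ones in the same shape. The prompt token IDs in the
# middle depend on the tokenizer (which may also insert its own BOS token).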
# 7) WebSocket endpoint (simplified via the streamer) ------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    print("🔌 Client connected")
    streamer = None
    main_loop = asyncio.get_running_loop()

    try:
        req_text = await ws.receive_text()
        print(f"Received request: {req_text}")
        req = json.loads(req_text)
        text = req.get("text", "Hallo Welt, wie geht es dir heute?")
        voice = req.get("voice", "Jakob")

        if not text:
            print("⚠️ Request text is empty.")
            await ws.close(code=1003, reason="Text cannot be empty")
            return

        print(f"Generating audio for: '{text}' with voice '{voice}'")
        ids, attn = build_prompt(text, voice)
        masker.reset()
        # Pass the constant EOS_TOKEN to the streamer
        streamer = AudioStreamer(ws, snac, masker, main_loop, device, EOS_TOKEN)

        print("Starting generation in background thread...")
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
            max_new_tokens=2500,  # Adjust as needed
            logits_processor=[masker],
            stopping_criteria=stopping_criteria,
            # --- Sampling parameters with anti-repetition measures ---
            do_sample=True,
            temperature=0.6,          # Adjust if needed
            top_p=0.9,                # Adjust if needed
            repetition_penalty=1.2,   # Increased (experiment!)
            no_repeat_ngram_size=4,   # Added (experiment!)
            # --- End sampling parameters ---
            use_cache=True,
            streamer=streamer,
            eos_token_id=EOS_TOKEN,  # Explicitly pass the constant EOS ID
        )
        print("Generation thread finished.")

    except WebSocketDisconnect:
        print("🔌 Client disconnected.")
    except json.JSONDecodeError:
        print("❌ Invalid JSON received.")
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1003, reason="Invalid JSON format")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"❌ WS error: {e}\n{error_details}", flush=True)
        error_payload = json.dumps({"error": str(e)})
        try:
            if ws.client_state.name == "CONNECTED":
                await ws.send_text(error_payload)
        except Exception:
            pass
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)
    finally:
        if streamer:
            try:
                streamer.end()
            except Exception as e_end:
                print(f"Error during streamer.end(): {e_end}")

        print("Closing connection.")
        if ws.client_state.name == "CONNECTED":
            try:
                await ws.close(code=1000)
            except RuntimeError as e_close:
                if "Cannot call \"send\"" not in str(e_close) and "Connection is closed" not in str(e_close):
                    print(f"Runtime error closing websocket: {e_close}")
            except Exception as e_close_final:
                print(f"Error closing websocket: {e_close_final}")
        elif ws.client_state.name != "DISCONNECTED":
            print(f"WebSocket final state: {ws.client_state.name}")
        print("Connection closed.")

# 8) Dev entry point ---------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    print("Starting Uvicorn server...")
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
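# 9) Example client (reference only) -----------------------------------
# Minimal sketch of a test client, assuming the third-party 'websockets'
# package is installed (pip install websockets) and the server is reachable
# at ws://localhost:7860. This function is not used by the server; it is
# included purely as a usage reference, and the import stays local so the
# server itself does not depend on 'websockets'.
async def example_client(text: str = "Hallo Welt, wie geht es dir heute?",
                         voice: str = "Jakob",
                         out_path: str = "out.wav"):
    import wave
    import websockets  # assumed extra dependency, only needed for this example

    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        # The endpoint expects a single JSON message with "text" and "voice"
        await ws.send(json.dumps({"text": text, "voice": voice}))
        with wave.open(out_path, "wb") as wav:
            wav.setnchannels(1)      # mono
            wav.setsampwidth(2)      # int16 PCM, matching the streamer output
            wav.setframerate(24000)  # SNAC 24 kHz decoder
            try:
                while True:
                    chunk = await ws.recv()
                    if isinstance(chunk, bytes):
                        wav.writeframes(chunk)
            except websockets.exceptions.ConnectionClosed:
                # The server closes the socket when generation is finished
                pass

# With the server already running, the example client can be invoked via e.g.:
#   python -c "import asyncio, app; asyncio.run(app.example_client())"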