# app.py ────────────────────────────────────────────────────────────── import os import json import torch import asyncio import traceback # Import traceback for better error logging from fastapi import FastAPI, WebSocket, WebSocketDisconnect from huggingface_hub import login from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList # Import BaseStreamer for the interface from transformers.generation.streamers import BaseStreamer from snac import SNAC # Ensure you have 'pip install snac' # --- Globals (populated in load_models) --- tok = None model = None snac = None masker = None stopping_criteria = None device = "cuda" if torch.cuda.is_available() else "cpu" # 0) Login + Device --------------------------------------------------- HF_TOKEN = os.getenv("HF_TOKEN") if HF_TOKEN: print("🔑 Logging in to Hugging Face Hub...") # Consider adding error handling for login failure if necessary login(HF_TOKEN) # torch.backends.cuda.enable_flash_sdp(False) # Uncomment if needed for PyTorch‑2.2‑Bug # 1) Konstanten ------------------------------------------------------- REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1" START_TOKEN = 128259 NEW_BLOCK = 128257 # Token indicating start of audio generation EOS_TOKEN = 128258 # End Of Speech token AUDIO_BASE = 128266 # Base ID for audio tokens AUDIO_SPAN = 4096 * 7 # 7 codebooks * 4096 codes per book = 28672 possible audio tokens # Create AUDIO_IDS on the correct device later in load_models AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN) # 2) Logit‑Mask ------------------------------------------------------- class AudioMask(LogitsProcessor): """ Manages allowed tokens during generation. - Initially allows NEW_BLOCK and AUDIO tokens. - Allows EOS_TOKEN only after at least one audio block has been sent. """ def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int): super().__init__() # Allow NEW_BLOCK and all valid audio tokens initially self.allow_initial = torch.cat([ torch.tensor([new_block_token_id], device=audio_ids.device), audio_ids ], dim=0) self.eos = torch.tensor([eos_token_id], device=audio_ids.device) # Precompute combined tensor for allowing audio, NEW_BLOCK, and EOS self.allow_with_eos = torch.cat([self.allow_initial, self.eos], dim=0) self.sent_blocks = 0 # State: Number of audio blocks sent def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # Determine which tokens are allowed based on whether blocks have been sent current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow_initial # Create a mask initialized to negative infinity mask = torch.full_like(scores, float("-inf")) # Set allowed token scores to 0 (effectively allowing them) mask[:, current_allow] = 0 # Apply the mask to the scores return scores + mask def reset(self): """Resets the state for a new generation request.""" self.sent_blocks = 0 # 3) StoppingCriteria für EOS --------------------------------------- # generate() needs explicit stopping criteria when using a streamer class EosStoppingCriteria(StoppingCriteria): def __init__(self, eos_token_id: int): self.eos_token_id = eos_token_id def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: # Check if the *last* generated token is the EOS token # Check input_ids shape to prevent index error on first token if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id: print("StoppingCriteria: EOS detected.") return True return False # 4) Benutzerdefinierter AudioStreamer ------------------------------- class AudioStreamer(BaseStreamer): """ Custom streamer to process audio tokens, decode them using SNAC, and send audio bytes over a WebSocket. """ # Added target_device parameter def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str): self.ws = ws self.snac = snac_decoder self.masker = audio_mask # Reference to the mask to update sent_blocks self.loop = loop # Event loop of the main thread for run_coroutine_threadsafe # Use the passed target_device self.device = target_device self.buf: list[int] = [] # Buffer for audio token values (AUDIO_BASE subtracted) self.tasks = set() # Keep track of pending send tasks def _decode_block(self, block7: list[int]) -> bytes: """ Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes. NOTE: The mapping from the 7 tokens to the 3 SNAC codebooks (l1, l2, l3) is CRITICAL and based on the structure used by the specific model. This implementation uses the mapping derived from the user's previous code. If audio is distorted, try the alternative mapping commented out below. """ if len(block7) != 7: print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.") return b"" # Return empty bytes if block is incomplete # --- Mapping based on user's previous version --- try: l1 = [block7[0]] # Index 0 l2 = [block7[1], block7[4]] # Indices 1, 4 l3 = [block7[2], block7[3], block7[5], block7[6]] # Indices 2, 3, 5, 6 # --- Alternative Hypothesis Mapping (Try if above fails) --- # l1 = [block7[0], block7[3], block7[6]] # Indices 0, 3, 6 # l2 = [block7[1], block7[4]] # Indices 1, 4 # l3 = [block7[2], block7[5]] # Indices 2, 5 except IndexError as e: print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}, Error: {e}") return b"" # Convert lists to tensors on the correct device try: codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0) codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0) codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0) codes = [codes_l1, codes_l2, codes_l3] # List of tensors for SNAC except Exception as e: print(f"Streamer Error: Failed converting lists to tensors. Error: {e}") return b"" # Decode using SNAC try: with torch.no_grad(): # Ensure snac_decoder is on the correct device already (done via .to(device)) audio = self.snac.decode(codes)[0] # Decode expects list of tensors, result might have batch dim except Exception as e: print(f"Streamer Error: snac.decode failed. Input shapes: {[c.shape for c in codes]}. Error: {e}") return b"" # Squeeze, move to CPU, convert to numpy audio_np = audio.squeeze().detach().cpu().numpy() # Convert to 16-bit PCM bytes audio_bytes = (audio_np * 32767).astype("int16").tobytes() return audio_bytes async def _send_audio_bytes(self, data: bytes): """Coroutine to send bytes over WebSocket.""" if not data: # Don't send empty bytes return try: await self.ws.send_bytes(data) # print(f"Streamer: Sent {len(data)} audio bytes.") # Optional: Debug log except WebSocketDisconnect: print("Streamer: WebSocket disconnected during send.") except Exception as e: print(f"Streamer: Error sending bytes: {e}") def put(self, value: torch.LongTensor): """ Receives new token IDs (Tensor) from generate() (runs in worker thread). Processes tokens, decodes full blocks, and schedules sending via run_coroutine_threadsafe. """ # Ensure value is on CPU and flatten to a list of ints if value.numel() == 0: return # Handle potential shape issues, ensure it's iterable try: new_token_ids = value.view(-1).tolist() except Exception as e: print(f"Streamer Error: Could not process incoming tensor: {value}, Error: {e}") return for t in new_token_ids: if t == EOS_TOKEN: # print("Streamer: EOS token encountered.") # EOS is handled by StoppingCriteria, no action needed here except maybe logging. break # Stop processing this batch if EOS is found if t == NEW_BLOCK: