# app.py ─────────────────────────────────────────────────────────────
import os
import json
import asyncio

import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache
from snac import SNAC

# ── 0. HF login & device ─────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Work around the Flash-Attention bug in PyTorch 2.2.x
torch.backends.cuda.enable_flash_sdp(False)

# ── 1. Constants ─────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50          # tokens generated per mini-generate call
START_TOKEN = 128259
NEW_BLOCK_TOKEN = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266        # first audio code
# An audio frame is 7 tokens and the token at frame position i carries an
# offset of i * 4096 (decode_block subtracts 0, 4096, ..., 24576).  The set
# of valid audio ids must therefore span 7 * 4096 ids; allowing only the
# first 4096 would forbid every token for positions 1..6.
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096)


# ── 2. Dynamic logit masker ──────────────────────────────────────────
class DynamicAudioMask(LogitsProcessor):
    """Restrict sampling to audio tokens and the new-block control token.

    EOS stays masked out until at least ``min_audio_blocks`` audio blocks
    have been sent.  The caller reports progress by incrementing
    ``blocks_done``; the processor itself never changes that counter.

    Args:
        audio_ids: 1-D tensor of token ids that are valid audio codes.
        min_audio_blocks: number of blocks required before EOS is allowed.
    """

    def __init__(self, audio_ids: torch.Tensor, min_audio_blocks: int = 1):
        super().__init__()
        self.audio_ids = audio_ids
        self.ctrl_ids = torch.tensor([NEW_BLOCK_TOKEN], device=audio_ids.device)
        self.min_blocks = min_audio_blocks
        self.blocks_done = 0
        # Build both id sets once instead of concatenating on every step.
        self._allowed = torch.cat([self.audio_ids, self.ctrl_ids])
        self._allowed_with_eos = torch.cat(
            [self._allowed, torch.tensor([EOS_TOKEN], device=audio_ids.device)]
        )

    def reset(self) -> None:
        """Forget previously counted blocks so EOS is blocked again."""
        self.blocks_done = 0

    def __call__(self, input_ids, scores):
        """Return ``scores`` with every disallowed token set to ``-inf``."""
        if self.blocks_done >= self.min_blocks:
            allowed = self._allowed_with_eos   # EOS may be sampled now
        else:
            allowed = self._allowed
        # No-op when the ids already live on the scores' device.
        allowed = allowed.to(scores.device)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed] = 0
        return scores + mask


# ── 3.
# FastAPI skeleton ───────────────────────────────────────────────────
app = FastAPI()

# Number of entries per SNAC codebook; also the per-position token offset
# that decode_block removes.
_CODEBOOK_SIZE = 4096


@app.get("/")
async def ping():
    """Liveness probe."""
    return {"msg": "Orpheus‑TTS up & running"}


@app.on_event("startup")
async def load_models():
    """Load tokenizer, language model and SNAC decoder into module globals."""
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True
    masker = DynamicAudioMask(VALID_AUDIO_IDS.to(device))
    print("✅ Modelle geladen")


# ── 4. Helper functions ──────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Tokenize ``"<voice>: <text>"`` and wrap it in the prompt control tokens.

    Returns:
        ``(input_ids, attention_mask)``, both of shape ``(1, seq_len)`` on
        ``device``.  The mask is all ones because nothing is padded.
    """
    prompt = f"{voice}: {text}"
    ids = tok(prompt, return_tensors="pt").input_ids.to(device)
    ids = torch.cat(
        [
            torch.tensor([[START_TOKEN]], device=device),
            ids,
            # NOTE(review): presumably the end-of-text / end-of-prompt
            # markers this model expects; confirm against the model card.
            torch.tensor([[128009, 128260]], device=device),
        ],
        dim=1,
    )
    attn = torch.ones_like(ids)
    return ids, attn


def decode_block(block7: list[int]) -> bytes:
    """Decode one 7-token audio frame into 16-bit PCM bytes.

    Args:
        block7: seven audio codes, already shifted by ``AUDIO_BASE``.  The
            code at position ``i`` still carries an offset of ``i * 4096``.

    Returns:
        The decoded samples as native-endian int16 bytes.

    Raises:
        ValueError: if the frame does not hold exactly 7 codes or a code is
            outside the codebook after its offset is removed.
    """
    if len(block7) != 7:
        raise ValueError(f"expected 7 codes per block, got {len(block7)}")

    codes = [c - pos * _CODEBOOK_SIZE for pos, c in enumerate(block7)]
    # Reject bad codes here: an out-of-range index inside the decoder's
    # codebook lookup would fail far less gracefully, especially on CUDA.
    if any(not 0 <= c < _CODEBOOK_SIZE for c in codes):
        raise ValueError(f"audio codes out of range: {block7}")

    # Positions map onto the three SNAC layers as 1 / 2 / 4 codes per frame.
    layers = (
        [codes[0]],
        [codes[1], codes[4]],
        [codes[2], codes[3], codes[5], codes[6]],
    )
    code_tensors = [
        torch.tensor(layer, device=device).unsqueeze(0) for layer in layers
    ]
    # Without inference mode the result requires grad and .numpy() raises.
    with torch.inference_mode():
        audio = snac.decode(code_tensors)
    # Clamp first so samples slightly outside [-1, 1] do not wrap in int16.
    pcm = audio.squeeze().float().clamp(-1.0, 1.0).cpu().numpy()
    return (pcm * 32767).astype("int16").tobytes()


# ── 5.
# WebSocket TTS endpoint ─────────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech for a single request over a websocket.

    The client sends one JSON text message with the optional keys ``text``
    and ``voice`` (default ``"Jakob"``).  The server replies with binary
    messages, each holding the int16 PCM bytes of one decoded 7-token audio
    frame, and closes the socket when the model emits EOS.

    Close codes: 1000 on normal completion, 1003 when the request is not a
    JSON object, 1011 on an internal error.
    """
    await ws.accept()
    close_code = 1000
    try:
        try:
            req = json.loads(await ws.receive_text())
        except json.JSONDecodeError:
            req = None
        if not isinstance(req, dict):
            close_code = 1003          # client sent unusable data
            return

        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_inputs(text, voice)   # complete prompt
        # One masker per connection: its blocks_done counter must start at
        # zero for every request and must not be shared between clients.
        request_masker = DynamicAudioMask(VALID_AUDIO_IDS.to(device))
        past = None
        buf = []
        finished = False

        while not finished:
            gen_kwargs = dict(
                input_ids=ids,
                attention_mask=attn,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[request_masker],
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                return_dict_in_generate=True,
                use_cache=True,
            )
            if past is not None:
                gen_kwargs["past_key_values"] = past
            # generate() blocks for a long time; keep the event loop free.
            out = await asyncio.to_thread(model.generate, **gen_kwargs)

            # ----- new tokens ------------------------------------------
            # sequences = the input ids followed by the generated tokens.
            new = out.sequences[0, ids.shape[1]:].tolist()
            print("new tokens:", new[:20])      # debug print
            if not new:                         # safety: nothing produced
                break

            # ----- token handling --------------------------------------
            for t in new:
                if t == EOS_TOKEN:
                    finished = True
                    break
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    request_masker.blocks_done += 1   # EOS allowed after 1st block

            # Continue from the full sequence plus the returned cache, so
            # only the tokens not yet in the cache are processed next round.
            ids = out.sequences
            attn = torch.ones_like(ids)
            past = out.past_key_values
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("❌ WS‑Error:", e)
        close_code = 1011
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close(code=close_code)
            except RuntimeError:
                # A close frame was already sent.
                pass


# ── 6. Local start (uvicorn) ─────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)