# app.py ─────────────────────────────────────────────────────────────
import os, json, asyncio, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
from snac import SNAC

# ── 0. HF auth & device ──────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Work around a Flash-Attention bug in PyTorch 2.2.x
torch.backends.cuda.enable_flash_sdp(False)

# ── 1. Constants ─────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50
START_TOKEN = 128259
NEW_BLOCK_TOKEN = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266
# 7 codebook positions with 4096 codes each; decode_snac() below removes
# the per-position offsets of 4096*k
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096)

# ── 2. Logits processor for masking ──────────────────────────────────
class AudioLogitMask(LogitsProcessor):
    """Restrict sampling to audio tokens plus the NEW_BLOCK and EOS tokens."""

    def __init__(self, allowed_ids: torch.Tensor):
        super().__init__()
        self.allowed = allowed_ids

    def __call__(self, input_ids, scores):
        # scores shape: [batch, vocab]
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = 0
        return scores + mask

ALLOWED_IDS = torch.cat(
    [VALID_AUDIO_IDS, torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])]
).to(device)
MASKER = AudioLogitMask(ALLOWED_IDS)

# ── 3. FastAPI skeleton ──────────────────────────────────────────────
app = FastAPI()

@app.get("/")
async def ping():
    return {"msg": "Orpheus-TTS OK"}

@app.on_event("startup")
async def load_models():
    global tok, model, snac
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True

# ── 4. Helper functions ──────────────────────────────────────────────
def build_prompt(text: str, voice: str):
    """Wrap the text in the Orpheus prompt format and return ids plus attention mask."""
    base = f"{voice}: {text}"
    ids = tok(base, return_tensors="pt").input_ids.to(device)
    ids = torch.cat(
        [
            torch.tensor([[START_TOKEN]], device=device),
            ids,
            torch.tensor([[128009, 128260]], device=device),
        ],
        1,
    )
    return ids, torch.ones_like(ids)

def decode_snac(block7: list[int]) -> bytes:
    """Turn one block of 7 offset-corrected audio tokens into 16-bit PCM bytes."""
    l1, l2, l3 = [], [], []
    b = block7
    l1.append(b[0])
    l2.append(b[1] - 4096)
    l3.extend([b[2] - 8192, b[3] - 12288])
    l2.append(b[4] - 16384)
    l3.extend([b[5] - 20480, b[6] - 24576])
    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
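# Illustrative use of decode_snac(): the WebSocket loop below collects 7
# offset-corrected audio tokens (t - AUDIO_BASE) per block before decoding.
# The values here are made up and only respect the per-position ranges:
#   pcm_bytes = decode_snac([123, 4500, 9000, 13000, 17000, 21000, 25000])
#   # layer 1 <- b[0]; layer 2 <- b[1], b[4]; layer 3 <- b[2], b[3], b[5], b[6]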
# ── 5. WebSocket endpoint ────────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_prompt(text, voice)
        past = None
        buf = []

        while True:
            out = model.generate(
                input_ids=ids,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[MASKER],
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                use_cache=True,
                return_dict_in_generate=True,
            )
            past = out.past_key_values
            # only the tokens generated in this chunk are new
            newtok = out.sequences[0, ids.shape[1]:].tolist()

            for t in newtok:
                if t == EOS_TOKEN:
                    raise StopIteration          # caught below: normal end of speech
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_snac(buf))
                    buf.clear()

            # from here on, keep generating with the cache: feed the full
            # sequence back in; the KV cache avoids re-computing the prefix
            ids = out.sequences
            attn = torch.ones_like(ids)

    except (StopIteration, WebSocketDisconnect):
        pass                                     # normal end
    except Exception as e:
        print("WS error:", e)
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)            # error code only if still open
    finally:
        try:
            if ws.client_state.name != "DISCONNECTED":
                await ws.close()                 # clean close
        except RuntimeError:
            # Starlette has already sent a close frame
            pass

# ── 6. Local test ────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
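# ── 7. Example client (sketch) ───────────────────────────────────────
# A minimal sketch of how a client could consume the /ws/tts stream. It
# assumes the third-party `websockets` package is installed; the URL,
# output file name and function name are illustrative, not part of this app.
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           pcm = bytearray()
#           try:
#               while True:
#                   pcm += await ws.recv()   # raw 16-bit mono PCM @ 24 kHz
#           except websockets.ConnectionClosed:
#               pass                         # server closed after EOS
#           with open("out.pcm", "wb") as f:
#               f.write(pcm)
#
#   asyncio.run(demo())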