import os
import json
import asyncio
import torch

# Workaround for the PyTorch 2.2.x flash SDP assertion
torch.backends.cuda.enable_flash_sdp(False)

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- HF token & login ---
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# --- Select device ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Instantiate FastAPI ---
app = FastAPI()


# --- Hello route so GET / no longer returns a 404 ---
@app.get("/")
async def read_root():
    return {"message": "Orpheus TTS WebSocket server is running"}


# --- Load models on startup ---
@app.on_event("startup")
async def load_models():
    global tokenizer, model, snac

    # SNAC for audio decoding
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # Orpheus TTS base model
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
        return_legacy_cache=True,  # for compatibility with past_key_values as a tuple
    ).to(device)
    model.config.pad_token_id = model.config.eos_token_id

    # --- Prepare logit masking ---
    # pure audio tokens range from 128266 to 128266 + 4096 - 1
    AUDIO_OFFSET = 128266
    AUDIO_COUNT = 4096
    valid_audio = torch.arange(AUDIO_OFFSET, AUDIO_OFFSET + AUDIO_COUNT, device=device)
    ctrl_tokens = torch.tensor([128257, model.config.eos_token_id], device=device)
    global ALLOWED_IDS
    ALLOWED_IDS = torch.cat([valid_audio, ctrl_tokens])


def sample_from_logits(logits: torch.Tensor) -> int:
    """
    Masks all IDs except ALLOWED_IDS and then samples one token.
    """
    # logits: [1, vocab_size]
    mask = torch.full_like(logits, float("-inf"))
    mask[0, ALLOWED_IDS] = 0.0
    probs = torch.softmax(logits + mask, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()


def prepare_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # start/end markers
    start = torch.tensor([[128259]], dtype=torch.int64, device=device)
    end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    input_ids = torch.cat([start, ids, end], dim=1)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def decode_block(block: list[int]) -> bytes:
    """
    Decode 7 sampled audio codes into one block of PCM16 bytes.
    Expects block[i] = raw_token - 128266.
    """
    # redistribute the 7 codes onto the 3 SNAC layers
    layer1, layer2, layer3 = [], [], []
    b = block
    layer1.append(b[0])
    layer2.append(b[1] - 4096)
    layer3.append(b[2] - 2 * 4096)
    layer3.append(b[3] - 3 * 4096)
    layer2.append(b[4] - 4 * 4096)
    layer3.append(b[5] - 5 * 4096)
    layer3.append(b[6] - 6 * 4096)
    dev = next(snac.parameters()).device
    codes = [
        torch.tensor(layer1, device=dev).unsqueeze(0),
        torch.tensor(layer2, device=dev).unsqueeze(0),
        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
    with torch.no_grad():
        audio = snac.decode(codes).squeeze().cpu().numpy()
    # convert to PCM16
    pcm16 = (audio * 32767).astype("int16").tobytes()
    return pcm16


# --- WebSocket endpoint for TTS streaming ---
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # prepare inputs
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        buffer = []  # collects the 7 audio codes of the current block

        # token-by-token generation loop
        while True:
            with torch.no_grad():
                out = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_kvs,
                    use_cache=True,
                    return_dict=True,
                )
            past_kvs = out.past_key_values
            next_tok = sample_from_logits(out.logits[:, -1, :])

            # end of generation?
            if next_tok == model.config.eos_token_id:
                break

            # next step: feed only this token back in and extend the
            # attention mask over the cached sequence
            input_ids = torch.tensor([[next_tok]], device=device)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((1, 1), dtype=attention_mask.dtype, device=device)],
                dim=1,
            )

            # reset the buffer at the start of a new audio block
            if next_tok == 128257:
                buffer.clear()
                continue

            # collect the audio code (subtract the offset)
            buffer.append(next_tok - 128266)

            # once we have 7 codes, decode and send
            if len(buffer) == 7:
                pcm = decode_block(buffer)
                buffer.clear()
                await ws.send_bytes(pcm)

        # close cleanly
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)


# --- CLI for local testing ---
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)