"""WebSocket text-to-speech service: causal LM audio tokens decoded with SNAC.

A client connects to ``/ws/tts``, sends one JSON message of the form
``{"text": "...", "voice": "..."}`` and receives the synthesised speech as
binary PCM16 frames, after which the server closes the connection.
"""

import asyncio
import json
import os
from concurrent.futures import ThreadPoolExecutor

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import PlainTextResponse
from peft import PeftModel
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Environment & Hugging Face auth ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    # Mirror the token under a second variable name.
    # NOTE(review): huggingface_hub presumably already reads HF_TOKEN itself;
    # confirm whether this extra variable is still needed.
    os.environ["HUGGINGFACE_HUB_TOKEN"] = HF_TOKEN

# --- Token-id layout used by the prompt and the generated audio stream ---
PROMPT_START_TOKEN = 128259          # marker placed before the text prompt
PROMPT_END_TOKENS = (128009, 128260)  # markers placed after the text prompt
AUDIO_START_TOKEN = 128257           # last occurrence marks the audio start
AUDIO_END_TOKEN = 128258             # EOS of the audio stream
AUDIO_TOKEN_OFFSET = 128266          # id of the first audio code token
CODEBOOK_SIZE = 4096                 # per-position offset inside a frame
FRAME_SIZE = 7                       # audio tokens per SNAC frame
CHUNK_BYTES = 2400 * 2               # 0.1 s of PCM16 at 24 kHz
CHUNK_PAUSE_SECONDS = 0.1            # pacing between chunks

# --- FastAPI ---
app = FastAPI()


@app.get("/")
async def hello():
    """Return a plain-text greeting (simple liveness endpoint)."""
    return PlainTextResponse("Hallo Welt!")


# --- Device selection ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Load SNAC ---
print("Loading SNAC model…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

# --- Load the Orpheus/Kartoffel 3B model through PEFT ---
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
print(f"Loading base LM + PEFT from {model_name}…")
base = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# NOTE(review): the same repository id is used for the base model and the
# adapter; confirm the repository really ships PEFT adapter weights.
model = PeftModel.from_pretrained(
    base,
    model_name,
    device_map="auto",
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Make sure pad_token_id is set.
model.config.pad_token_id = model.config.eos_token_id

# All model work is funnelled through one worker thread: the event loop stays
# responsive while inference for different clients is still serialized.
_INFERENCE_POOL = ThreadPoolExecutor(max_workers=1)


async def _run_inference(fn, *args):
    """Run blocking ``fn(*args)`` on the inference thread and await its result."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_INFERENCE_POOL, fn, *args)


# --- Helpers ---
def prepare_prompt(text: str, voice: str):
    """Wrap the tokenized prompt in the start and end marker tokens.

    Args:
        text: Text to synthesise.
        voice: Optional speaker name; when non-empty the prompt becomes
            ``"<voice>: <text>"``.

    Returns:
        ``(input_ids, attention_mask)``, both of shape ``(1, seq_len)`` on
        ``device``; the mask is all ones.
    """
    full = f"{voice}: {text}" if voice else text
    start = torch.tensor([[PROMPT_START_TOKEN]], dtype=torch.int64)
    end = torch.tensor([list(PROMPT_END_TOKENS)], dtype=torch.int64)
    enc = tokenizer(full, return_tensors="pt").input_ids
    seq = torch.cat([start, enc, end], dim=1).to(device)
    mask = torch.ones_like(seq)
    return seq, mask


def extract_audio_tokens(generated: torch.LongTensor):
    """Extract the audio codes from a generated sequence.

    Everything up to and including the last audio-start token is dropped (the
    sequence is used as-is when no such token exists), EOS tokens are removed,
    the remainder is truncated to a multiple of seven and every id is shifted
    by ``AUDIO_TOKEN_OFFSET``.

    Args:
        generated: Token ids of shape ``(1, seq_len)``.

    Returns:
        A list of ints whose length is a multiple of ``FRAME_SIZE``.
    """
    start_positions = (generated == AUDIO_START_TOKEN).nonzero(as_tuple=True)[1]
    if start_positions.numel() > 0:
        cropped = generated[:, start_positions[-1].item() + 1:]
    else:
        cropped = generated

    row = cropped[0]
    row = row[row != AUDIO_END_TOKEN]
    usable = (row.size(0) // FRAME_SIZE) * FRAME_SIZE
    return (row[:usable] - AUDIO_TOKEN_OFFSET).tolist()


def _frames_to_snac_codes(tokens: list[int]):
    """Split 7-token frames into the three SNAC layers, skipping invalid frames."""
    layer1, layer2, layer3 = [], [], []
    for begin in range(0, len(tokens) - FRAME_SIZE + 1, FRAME_SIZE):
        frame = tokens[begin:begin + FRAME_SIZE]
        # Position k inside a frame carries an offset of k * CODEBOOK_SIZE.
        codes = [tok - pos * CODEBOOK_SIZE for pos, tok in enumerate(frame)]
        if any(code < 0 or code >= CODEBOOK_SIZE for code in codes):
            # Out-of-range codes would make the decoder's lookup fail.
            continue
        layer1.append(codes[0])
        layer2.extend((codes[1], codes[4]))
        layer3.extend((codes[2], codes[3], codes[5], codes[6]))
    return layer1, layer2, layer3


def _decode_pcm16(tokens: list[int]) -> bytes:
    """Decode audio codes with SNAC and return little-endian PCM16 bytes."""
    layers = _frames_to_snac_codes(tokens)
    if not layers[0]:
        return b""

    codes = [
        torch.tensor(layer, dtype=torch.long, device=device).unsqueeze(0)
        for layer in layers
    ]
    # Without disabling autograd the decoder output requires grad and cannot
    # be converted to NumPy.
    with torch.inference_mode():
        audio = snac.decode(codes).reshape(-1).float()
        # Clamp before scaling so loud samples saturate instead of wrapping.
        samples = (audio.clamp(-1.0, 1.0) * 32767).to(torch.int16)
        return samples.cpu().numpy().tobytes()


async def decode_and_stream(tokens: list[int], ws: WebSocket):
    """Decode audio codes to PCM16 and stream them in 0.1 s chunks.

    Args:
        tokens: Audio codes as returned by ``extract_audio_tokens``.
        ws: Accepted WebSocket that receives the binary chunks.
    """
    pcm16 = await _run_inference(_decode_pcm16, tokens)
    for begin in range(0, len(pcm16), CHUNK_BYTES):
        await ws.send_bytes(pcm16[begin:begin + CHUNK_BYTES])
        # Pace the stream; without a pause the WebSocket can be overloaded.
        await asyncio.sleep(CHUNK_PAUSE_SECONDS)


def _synthesize_codes(text: str, voice: str):
    """Blocking: build the prompt, sample audio tokens and extract the codes."""
    ids, mask = prepare_prompt(text, voice)
    generated = model.generate(
        input_ids=ids,
        attention_mask=mask,
        max_new_tokens=4000,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.1,
        eos_token_id=AUDIO_END_TOKEN,
        # NOTE(review): forced_bos_token_id is presumably only meaningful for
        # encoder-decoder models; kept unchanged, confirm it can be removed.
        forced_bos_token_id=PROMPT_START_TOKEN,
        use_cache=True,
    )
    return extract_audio_tokens(generated)


# --- WebSocket TTS endpoint ---
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Handle one TTS request per connection.

    Expects a single JSON text message with ``text`` and optional ``voice``,
    streams the audio and closes with 1000. Malformed requests are closed
    with 1007, unexpected server errors with 1011.
    """
    await ws.accept()
    try:
        raw = await ws.receive_text()
        try:
            req = json.loads(raw)
        except json.JSONDecodeError:
            req = None

        text = req.get("text", "") if isinstance(req, dict) else None
        voice = req.get("voice", "") if isinstance(req, dict) else ""
        if not isinstance(text, str) or not text.strip():
            await ws.close(code=1007)
            return
        if not isinstance(voice, str):
            voice = ""

        codes = await _run_inference(_synthesize_codes, text, voice)
        await decode_and_stream(codes, ws)

        # Close cleanly once all audio has been sent.
        await ws.close(code=1000)

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except Exception:
            # Best effort: the socket may already be closed or broken.
            pass


# --- Run locally ---
if __name__ == "__main__":
    import uvicorn

    # Pass the app object so startup does not depend on the file being app.py.
    uvicorn.run(app, host="0.0.0.0", port=7860)