import os
import json

import numpy as np
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# -- HF token & login (if set) --
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# -- Pick device --
device = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI()


@app.get("/")
async def read_root():
    return {"message": "Hello, world!"}


# -- Global models --
model = None
tokenizer = None
snac_model = None


# -- Startup: load SNAC & Orpheus --
@app.on_event("startup")
async def load_models():
    global model, tokenizer, snac_model
    # 1) SNAC (24 kHz audio codec)
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # 2) Orpheus TTS
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map="auto" if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    # With device_map="auto" the model is already dispatched by accelerate,
    # and calling .to() on it can raise; only move it manually on CPU.
    if device != "cuda":
        model = model.to(device)
    model.config.pad_token_id = model.config.eos_token_id


# -- Markers and offsets from the template --
START_TOKEN = 128259           # marker prepended to the prompt
END_TOKENS = [128009, 128260]  # markers appended after the prompt
AUDIO_OFFSET = 128266          # id of the first audio code token


def process_single_prompt(prompt: str, voice: str) -> list[int]:
    # Assemble the prompt
    if voice and voice != "in_prompt":
        text = f"{voice}: {prompt}"
    else:
        text = prompt

    # Tokenize and wrap with marker tokens
    ids = tokenizer(text, return_tensors="pt").input_ids
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64)
    end = torch.tensor([END_TOKENS], dtype=torch.int64)
    input_ids = torch.cat([start, ids, end], dim=1).to(device)
    attention_mask = torch.ones_like(input_ids)

    # Generate
    gen = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=4000,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.1,
        eos_token_id=128258,
        use_cache=True,
    )

    # Find the last start-of-audio marker (128257) and crop everything before it
    token_to_find = 128257
    token_to_remove = 128258
    idxs = (gen == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = gen[:, idxs[-1] + 1 :]
    else:
        cropped = gen

    # Drop end-of-audio/padding tokens (128258)
    row = cropped[0][cropped[0] != token_to_remove]
    # Trim the length to a multiple of 7 (one SNAC frame = 7 tokens)
    new_len = (row.size(0) // 7) * 7
    trimmed = row[:new_len].tolist()
    # Shift token ids back into the SNAC code range
    return [t - AUDIO_OFFSET for t in trimmed]


def redistribute_codes(code_list: list[int]) -> np.ndarray:
    # Distribute each 7-token frame across the 3 SNAC layers and decode:
    # 1 code goes to layer 1, 2 to layer 2, 4 to layer 3, and each slot k
    # is de-offset by k * 4096 (the per-slot codebook size).
    layer1, layer2, layer3 = [], [], []
    for i in range(len(code_list) // 7):
        b = code_list[7 * i : 7 * i + 7]
        layer1.append(b[0])
        layer2.append(b[1] - 4096)
        layer3.append(b[2] - 2 * 4096)
        layer3.append(b[3] - 3 * 4096)
        layer2.append(b[4] - 4 * 4096)
        layer3.append(b[5] - 5 * 4096)
        layer3.append(b[6] - 6 * 4096)
    codes = [
        torch.tensor(layer1, device=device).unsqueeze(0),
        torch.tensor(layer2, device=device).unsqueeze(0),
        torch.tensor(layer3, device=device).unsqueeze(0),
    ]
    audio = snac_model.decode(codes).squeeze().cpu().numpy()
    return audio  # float32 @ 24 kHz
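# Offline usage sketch (hypothetical, not part of the server): once the two
# models above are loaded, the pipeline can be exercised without FastAPI.
# `soundfile` is an assumed dependency here; any WAV writer would do.
#
#   with torch.no_grad():
#       codes = process_single_prompt("Hallo, wie geht es dir?", "")
#       audio = redistribute_codes(codes)  # float32 mono @ 24 kHz
#   import soundfile as sf
#   sf.write("out.wav", audio, 24000)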
# -- WebSocket endpoint for TTS --
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        # 1) Receive text + voice
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "")

        # 2) Prompt -> list of audio codes
        with torch.no_grad():
            codes = process_single_prompt(text, voice)
            audio_np = redistribute_codes(codes)

        # 3) Convert to PCM16 and send (clip first so overflow can't wrap)
        pcm16 = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16).tobytes()
        await ws.send_bytes(pcm16)

        # 4) Close cleanly
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)


if __name__ == "__main__":
    import uvicorn

    # Assumes this file is saved as app.py
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
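# Client sketch (hypothetical): a minimal way to call /ws/tts, assuming the
# server runs on localhost:7860 and the third-party `websockets` package is
# installed. The endpoint replies with one message of raw 16-bit PCM at
# 24 kHz, which the `wave` module can wrap into a playable WAV file.
#
#   import asyncio, json, wave
#   import websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": ""}))
#           pcm16 = await ws.recv()
#           with wave.open("tts.wav", "wb") as f:
#               f.setnchannels(1)      # mono
#               f.setsampwidth(2)      # 16-bit samples
#               f.setframerate(24000)  # SNAC output rate
#               f.writeframes(pcm16)
#
#   asyncio.run(main())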