import asyncio
import json
import logging
import os
import threading
from typing import Optional

import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

# — ENV & AUTH —
HF_TOKEN = os.getenv("HF_TOKEN", "")
if HF_TOKEN:
    login(HF_TOKEN)

# — DEVICE SETUP —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — FASTAPI INSTANCE —
app = FastAPI()

# Populated by startup_event(); None until the models have been loaded.
tokenizer = None
model = None
snac = None

# Model work (generation chunks and SNAC decodes) runs in worker threads so the
# event loop stays responsive. This lock keeps at most one of those threads on
# the models at a time, mirroring the one-at-a-time behaviour of running inline.
_model_lock = threading.Lock()


# — HEALTHCHECK / ROOT —
@app.get("/")
async def read_root():
    """Liveness probe: report that the service is up."""
    return {"message": "TTS WebSocket up and running!"}


# — LOAD MODELS ON STARTUP —
@app.on_event("startup")
async def startup_event():
    """Load the SNAC vocoder and the TTS model/tokenizer into module globals."""
    global tokenizer, model, snac

    # 1) SNAC vocoder (used for inference only)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval()

    # 2) TTS model & tokenizer
    model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    # make pad == eos
    model.config.pad_token_id = model.config.eos_token_id


# — HELPERS —
START_TOKEN = 128259
END_TOKENS = [128009, 128260]
RESET_MARKER = 128257
EOS_TOKEN = 128258
AUDIO_TOKEN_OFFSET = 128266  # to subtract from token→audio code

CODES_PER_FRAME = 7    # audio codes consumed by one decode_seven() call
CODEBOOK_SIZE = 4096   # frame slot k carries codes offset by k * CODEBOOK_SIZE
CHUNK_TOKENS = 50      # new tokens requested per model.generate() call
MAX_NEW_TOKENS = 4096  # hard cap so a request cannot generate forever (tunable)


def prepare_inputs(text: str, voice: str):
    """Build the prompt ``[START_TOKEN] + tokenize(f"{voice}: {text}") + END_TOKENS``.

    Returns:
        ``(input_ids, attention_mask)``, each of shape ``(1, seq_len)`` on
        ``device``; the mask is all ones.
    """
    prompt = f"{voice}: {text}"
    in_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
    ids = torch.cat([start, in_ids, end], dim=1)
    mask = torch.ones_like(ids)
    return ids, mask


def _audio_code(token: int, slot: int) -> Optional[int]:
    """Map a generated token id to the audio code for frame position ``slot``.

    Returns the code in the form ``decode_seven`` expects (``token -
    AUDIO_TOKEN_OFFSET``, still carrying its ``slot * CODEBOOK_SIZE`` offset),
    or ``None`` when the token does not fall in that slot's code range.
    """
    code = token - AUDIO_TOKEN_OFFSET
    if 0 <= code - slot * CODEBOOK_SIZE < CODEBOOK_SIZE:
        return code
    return None


def decode_seven(tokens: list[int]) -> bytes:
    """Decode exactly 7 audio codes with SNAC into PCM16 bytes.

    ``tokens[k]`` must still include its ``k * 4096`` slot offset; the offsets
    are removed here while the codes are distributed over the three code
    levels passed to ``snac.decode``.

    Raises:
        ValueError: if ``tokens`` does not hold exactly 7 codes.
    """
    if len(tokens) != CODES_PER_FRAME:
        raise ValueError(f"expected {CODES_PER_FRAME} codes, got {len(tokens)}")
    b = tokens
    l1 = [b[0]]
    l2 = [b[1] - 1 * CODEBOOK_SIZE, b[4] - 4 * CODEBOOK_SIZE]
    l3 = [
        b[2] - 2 * CODEBOOK_SIZE,
        b[3] - 3 * CODEBOOK_SIZE,
        b[5] - 5 * CODEBOOK_SIZE,
        b[6] - 6 * CODEBOOK_SIZE,
    ]
    codes = [
        torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0)
        for level in (l1, l2, l3)
    ]
    # inference_mode: no autograd graph is built, and .numpy() would refuse a
    # tensor that requires grad.
    # NOTE(review): each frame is decoded in isolation, without neighbouring
    # frames as context; if clicks are audible at frame boundaries, consider
    # decoding overlapping windows instead.
    with _model_lock, torch.inference_mode():
        audio = snac.decode(codes).squeeze()
        # Clamp before scaling: the int16 cast wraps around on overflow.
        audio = audio.clamp(-1.0, 1.0).float().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()


def _generate_chunk(sequence: torch.Tensor, cache, max_new_tokens: int):
    """Run one ``model.generate`` call that continues ``sequence``.

    ``cache`` is the ``past_key_values`` returned by the previous call (``None``
    on the first call). Passing the full sequence together with that cache lets
    generate() resume without re-encoding the tokens already cached.

    Returns:
        The generate output object; ``.sequences`` holds ``sequence`` followed
        by the new tokens, ``.past_key_values`` the updated cache.
    """
    with _model_lock:
        return model.generate(
            input_ids=sequence,
            attention_mask=torch.ones_like(sequence),
            past_key_values=cache,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=EOS_TOKEN,
            pad_token_id=EOS_TOKEN,
            use_cache=True,
            return_dict_in_generate=True,
        )


def _parse_request(msg: str) -> Optional[tuple[str, str]]:
    """Parse the client's JSON request into ``(text, voice)``.

    ``voice`` defaults to ``"Jakob"`` and ``text`` to ``""``. Returns ``None``
    if the message is not a JSON object or either field is not a string.
    """
    try:
        req = json.loads(msg)
    except json.JSONDecodeError:
        return None
    if not isinstance(req, dict):
        return None
    text = req.get("text", "")
    voice = req.get("voice", "Jakob")
    if not isinstance(text, str) or not isinstance(voice, str):
        return None
    return text, voice


async def _close_quietly(ws: WebSocket, code: int = 1000) -> None:
    """Close ``ws`` with ``code``, ignoring errors if it is already closed."""
    try:
        await ws.close(code=code)
    except (RuntimeError, OSError):
        logger.debug("WebSocket already closed", exc_info=True)


# — WEBSOCKET ENDPOINT —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech for one request as binary PCM16 messages.

    The client sends one JSON text message ``{"text": ..., "voice": ...}``.
    The server sends one binary message per decoded group of 7 audio codes,
    then closes the socket when EOS is generated or ``MAX_NEW_TOKENS`` is
    reached. Malformed requests are closed with code 1003; unexpected errors
    are logged and closed with 1011.
    """
    await ws.accept()
    try:
        # 1) receive JSON request
        parsed = _parse_request(await ws.receive_text())
        if parsed is None:
            await _close_quietly(ws, code=1003)
            return
        text, voice = parsed

        # 2) prepare prompt
        sequence, _ = prepare_inputs(text, voice)

        # 3) chunked generation: each chunk resumes from the previous KV cache
        cache = None
        buffer: list[int] = []
        generated = 0
        finished = False
        while not finished and generated < MAX_NEW_TOKENS:
            step = min(CHUNK_TOKENS, MAX_NEW_TOKENS - generated)
            out = await asyncio.to_thread(_generate_chunk, sequence, cache, step)
            new_tokens = out.sequences[0, sequence.size(1):].tolist()
            sequence, cache = out.sequences, out.past_key_values
            generated += len(new_tokens)
            if not new_tokens:
                break

            # 4) turn tokens into audio codes; stream every complete group
            for t in new_tokens:
                if t == EOS_TOKEN:
                    finished = True
                    break
                if t == RESET_MARKER:
                    buffer.clear()
                    continue
                code = _audio_code(t, len(buffer))
                if code is None:
                    # Not a valid audio token for this slot; SNAC would fail
                    # on it, so skip it.
                    continue
                buffer.append(code)
                if len(buffer) == CODES_PER_FRAME:
                    pcm_bytes = await asyncio.to_thread(decode_seven, buffer)
                    buffer = []
                    await ws.send_bytes(pcm_bytes)

        # 5) clean close
        await ws.close()
    except WebSocketDisconnect:
        pass  # client went away; nothing left to send
    except Exception:
        logger.exception("Error in /ws/tts")
        await _close_quietly(ws, code=1011)