import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- ENV & AUTH ---
HF_TOKEN = os.getenv("HF_TOKEN", "")
if HF_TOKEN:
    login(HF_TOKEN)

# --- DEVICE SETUP ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- FASTAPI INSTANCE ---
app = FastAPI()
# --- HEALTHCHECK / ROOT ---
@app.get("/")
async def read_root():
    return {"message": "TTS WebSocket up and running!"}
# --- LOAD MODELS ON STARTUP ---
@app.on_event("startup")
async def startup_event():
    global tokenizer, model, snac

    # 1) SNAC vocoder
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # 2) TTS model & tokenizer
    model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    # make pad == eos
    model.config.pad_token_id = model.config.eos_token_id
# --- HELPERS ---
START_TOKEN = 128259
END_TOKENS = [128009, 128260]
RESET_MARKER = 128257
EOS_TOKEN = 128258
AUDIO_TOKEN_OFFSET = 128266  # subtracted from a generated token id to get its audio code
def prepare_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    in_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
    ids = torch.cat([start, in_ids, end], dim=1)
    mask = torch.ones_like(ids)
    return ids, mask
def decode_seven(tokens: list[int]) -> bytes:
    """Take exactly 7 audio codes, build the SNAC input and decode to PCM16 bytes."""
    b = tokens
    # Each 7-token frame interleaves SNAC's three codebook levels; slot i carries
    # an extra offset of i * 4096 that must be removed before decoding.
    l1 = [b[0]]                                               # level 1: 1 code per frame
    l2 = [b[1] - 1 * 4096, b[4] - 4 * 4096]                   # level 2: 2 codes per frame
    l3 = [b[2] - 2 * 4096, b[3] - 3 * 4096,
          b[5] - 5 * 4096, b[6] - 6 * 4096]                   # level 3: 4 codes per frame
    codes = [
        torch.tensor(l1, device=device).unsqueeze(0),
        torch.tensor(l2, device=device).unsqueeze(0),
        torch.tensor(l3, device=device).unsqueeze(0),
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    pcm16 = (audio * 32767).astype("int16").tobytes()
    return pcm16
# --- WEBSOCKET ENDPOINT ---
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        # 1) receive JSON request
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # 2) prepare prompt
        input_ids, attention_mask = prepare_inputs(text, voice)
        prompt_len = input_ids.size(1)

        # 3) chunked generation setup
        past_kvs = None
        buffer: list[int] = []
        generated_offset = 0
        while True:
            # 4) generate up to 50 new tokens at a time, reusing the KV cache
            out = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=EOS_TOKEN,
                pad_token_id=EOS_TOKEN,
                use_cache=True,
                return_dict_in_generate=True,
                return_legacy_cache=True,
                past_key_values=past_kvs,
            )
            # with return_dict_in_generate=True the output carries both the full
            # sequence and the updated cache
            gen_ids = out.sequences
            past_kvs = out.past_key_values

            # 5) extract only the newly generated tokens
            seq = gen_ids[0]
            new_seq = seq[prompt_len + generated_offset:]
            generated_offset += new_seq.size(0)

            # feed the full sequence back in on the next iteration; the cache
            # already covers everything up to the newest tokens
            input_ids = gen_ids
            attention_mask = torch.ones_like(gen_ids)
            # 6) process each new token
            stop = False
            for t in new_seq.tolist():
                if t == EOS_TOKEN:
                    stop = True
                    break
                if t == RESET_MARKER:
                    buffer.clear()
                    continue
                # convert to audio code
                buffer.append(t - AUDIO_TOKEN_OFFSET)
                # once we have 7 codes, decode & stream
                if len(buffer) >= 7:
                    block = buffer[:7]
                    buffer = buffer[7:]
                    pcm_bytes = decode_seven(block)
                    await ws.send_bytes(pcm_bytes)

            if stop:
                break
        # 7) clean close
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
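
# --- Example client (separate script, kept commented out) ---
# A minimal sketch of how a client could consume this endpoint; it is not part
# of the app above. It assumes the server is reachable at
# ws://localhost:7860/ws/tts (7860 is the usual Spaces port) and uses the
# third-party `websockets` package plus the stdlib `wave` module to write the
# streamed PCM16 frames into a 24 kHz mono WAV file.
#
#   import asyncio, json, wave
#   import websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           # send the JSON request expected by tts_ws
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
#           with wave.open("out.wav", "wb") as wav:
#               wav.setnchannels(1)       # mono
#               wav.setsampwidth(2)       # 16-bit PCM
#               wav.setframerate(24000)   # SNAC 24 kHz vocoder
#               try:
#                   while True:
#                       frame = await ws.recv()
#                       if isinstance(frame, bytes):
#                           wav.writeframes(frame)
#               except websockets.ConnectionClosed:
#                   pass  # server closed the socket after the last frame
#
#   asyncio.run(main())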