Tomtom84's picture
Update app.py
fd06e70 verified
raw
history blame
4.84 kB
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# — ENV & AUTH —
# Log in to the Hugging Face Hub only when a token is supplied through the
# environment (presumably required for gated/private model repos — confirm).
HF_TOKEN = os.getenv("HF_TOKEN", "")
if HF_TOKEN:
    login(HF_TOKEN)

# — DEVICE SETUP —
# A plain device string shared by all helpers; startup_event compares it
# against "cuda" to pick the model dtype.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# — FASTAPI INSTANCE —
app = FastAPI()
# — HEALTHCHECK / ROOT —
@app.get("/")
async def read_root():
    """Health check: report that the service is reachable."""
    status = {"message": "TTS WebSocket up and running!"}
    return status
# — LOAD MODELS ON STARTUP —
@app.on_event("startup")
async def startup_event():
    """Load the SNAC vocoder and the TTS tokenizer/model into module globals.

    Populates the globals ``snac``, ``tokenizer`` and ``model`` that
    ``prepare_inputs``, ``decode_seven`` and the WebSocket handler read.
    """
    global tokenizer, model, snac

    # Vocoder that turns generated audio codes back into a waveform.
    vocoder = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac = vocoder.to(device)

    # Text-to-speech language model and its tokenizer.
    repo_id = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    # bfloat16 only on GPU; on CPU let transformers pick its default dtype.
    dtype = None
    if device == "cuda":
        dtype = torch.bfloat16
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        device_map="auto",
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )

    # Use the EOS id as the pad id so padding is defined on the config.
    cfg = model.config
    cfg.pad_token_id = cfg.eos_token_id
# — HELPERS —
# Special token ids of the TTS model's vocabulary, described by how this
# file uses them.

# Prepended to every prompt (prepare_inputs).
START_TOKEN = 128_259
# Appended to every prompt (prepare_inputs); a list so it can be wrapped as a
# single (1, 2) row tensor.
END_TOKENS = [128_009, 128_260]
# When streamed, discards the partially collected audio frame.
RESET_MARKER = 128_257
# Generation stops on this id; it also serves as the pad id.
EOS_TOKEN = 128_258
# Subtracted from a generated token id to obtain its audio code; the code
# still carries a per-slot offset of slot * 4096 that decode_seven removes.
AUDIO_TOKEN_OFFSET = 128_266
def prepare_inputs(text: str, voice: str):
    """Tokenize ``"<voice>: <text>"`` and wrap it in the prompt marker tokens.

    The resulting row is ``[START_TOKEN] + tokenizer ids + END_TOKENS``.

    Returns:
        ``(ids, mask)``: the concatenated token ids on ``device`` (one row)
        and an all-ones attention mask of the same shape, since there is no
        padding.
    """
    prompt_ids = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    prefix = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    suffix = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
    ids = torch.cat((prefix, prompt_ids, suffix), dim=1)
    return ids, torch.ones_like(ids)
def decode_seven(tokens: list[int]) -> bytes:
    """Decode one 7-code audio frame with SNAC and return PCM16 bytes.

    ``tokens`` are generated token ids with ``AUDIO_TOKEN_OFFSET`` already
    removed. Slot ``i`` (0..6) still carries an extra offset of ``i * 4096``.
    That offset is stripped here, and the seven codes are spread over SNAC's
    three code layers as ``[s0]``, ``[s1, s4]`` and ``[s2, s3, s5, s6]``.

    Args:
        tokens: exactly seven slot-offset audio codes.

    Returns:
        int16 samples in native byte order (``ndarray.tobytes``).

    Raises:
        ValueError: if ``tokens`` does not hold exactly seven codes, or a code
            falls outside ``[0, 4096)`` after its slot offset is removed (the
            range implied by the 4096 slot stride). Checking on the host avoids
            an out-of-range lookup inside SNAC, which on CUDA surfaces as a
            device-side assert rather than a catchable Python error.
    """
    if len(tokens) != 7:
        raise ValueError(f"decode_seven expects exactly 7 codes, got {len(tokens)}")
    codes = [code - slot * 4096 for slot, code in enumerate(tokens)]
    if not all(0 <= code < 4096 for code in codes):
        raise ValueError(f"audio code out of range [0, 4096): {codes}")

    layers = [
        [codes[0]],
        [codes[1], codes[4]],
        [codes[2], codes[3], codes[5], codes[6]],
    ]
    snac_codes = [
        torch.tensor(layer, dtype=torch.long, device=device).unsqueeze(0)
        for layer in layers
    ]

    # SNAC's parameters require grad by default. Without this, the decoded
    # tensor carries autograd history and ``.numpy()`` raises.
    with torch.inference_mode():
        audio = snac.decode(snac_codes)
    samples = audio.squeeze().float().cpu().numpy()

    # Clip before scaling: float->int16 casts of values outside [-1, 1] wrap
    # (or are undefined) instead of saturating, producing loud clicks.
    return (samples.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
# — WEBSOCKET ENDPOINT —
def _frame_is_valid(frame: list[int]) -> bool:
    """Return True if every code of a 7-code frame lies in its slot's range.

    Slot ``i`` carries an extra offset of ``i * 4096`` (the offsets
    ``decode_seven`` subtracts), so the de-offset value must be in ``[0, 4096)``.
    """
    return all(0 <= code - slot * 4096 < 4096 for slot, code in enumerate(frame))


@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech for one request over a WebSocket.

    Protocol:
    - The client sends one JSON text message ``{"text": ..., "voice": ...}``.
      ``voice`` defaults to ``"Jakob"``.
    - The server answers with binary messages. Each holds the PCM16 bytes of
      one decoded 7-code frame.
    - The server then closes normally.
    - On an unexpected error it closes with code 1011.
    """
    # Tokens generated per model.generate call. Small chunks let audio start
    # streaming early and let a client disconnect be noticed between chunks.
    chunk_tokens = 50
    # Hard cap on generated tokens, so a model that never emits EOS_TOKEN
    # cannot keep the connection (and the GPU) busy forever.
    max_total_tokens = 4096

    await ws.accept()
    try:
        # 1) receive JSON request
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # 2) prepare prompt
        seq, _ = prepare_inputs(text, voice)
        prompt_len = seq.size(1)
        consumed = prompt_len  # index in seq of the first unhandled token
        past_kvs = None
        buffer: list[int] = []
        stop = False

        while not stop:
            remaining = max_total_tokens - (consumed - prompt_len)
            if remaining <= 0:
                break

            # 3) Generate the next chunk in a worker thread; generate() is
            # synchronous and would otherwise block the event loop.
            # The full sequence is passed each time together with the cache from
            # the previous call; generate() only runs the not-yet-cached tokens.
            # NOTE(review): concurrent connections now run generate() in parallel
            # threads; add a shared lock if GPU memory is tight.
            out = await asyncio.to_thread(
                model.generate,
                input_ids=seq,
                attention_mask=torch.ones_like(seq),
                max_new_tokens=min(chunk_tokens, remaining),
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=EOS_TOKEN,
                pad_token_id=EOS_TOKEN,
                use_cache=True,
                return_dict_in_generate=True,
                past_key_values=past_kvs,
            )
            seq = out.sequences
            past_kvs = out.past_key_values

            # 4) only the tokens produced by this call
            new_tokens = seq[0, consumed:].tolist()
            consumed = seq.size(1)
            if not new_tokens:
                break  # nothing generated; avoid spinning forever

            # 5) turn audio tokens into 7-code frames and stream them
            for t in new_tokens:
                if t == EOS_TOKEN:
                    stop = True
                    break
                if t == RESET_MARKER:
                    buffer.clear()
                    continue
                if t < AUDIO_TOKEN_OFFSET:
                    # Non-audio token: it would become a negative code and
                    # crash the vocoder, so skip it.
                    continue
                buffer.append(t - AUDIO_TOKEN_OFFSET)
                if len(buffer) == 7:
                    frame, buffer = buffer, []
                    if not _frame_is_valid(frame):
                        continue  # drop malformed frame rather than crash SNAC
                    # Decode outside autograd so the numpy conversion works.
                    with torch.inference_mode():
                        pcm_bytes = decode_seven(frame)
                    await ws.send_bytes(pcm_bytes)

        # 6) clean close
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except Exception:
            # Best effort: the socket may already be closed by the peer.
            pass