import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# — HF token & login —
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)
# — Choose device —
device = "cuda" if torch.cuda.is_available() else "cpu"
# — Instantiate FastAPI —
app = FastAPI()
# — Hello route so that GET / does not return a 404 —
@app.get("/")
async def read_root():
    return {"message": "Hello, world!"}
# — Load models at startup —
@app.on_event("startup")
async def load_models():
    global tokenizer, model, snac
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map="auto",
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True
    )
    # Fall back to EOS as the pad token
    model.config.pad_token_id = model.config.eos_token_id
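# Note: the app is expected to be served with something like
# "uvicorn app:app --host 0.0.0.0 --port 7860" (module name and port are
# assumptions); load_models() then populates the globals used below.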
# — Helper functions —
START_TOKEN = 128259
END_TOKENS = [128009, 128260]
RESET_TOKEN = 128257
AUDIO_OFFSET = 128266
# The model is only loaded at startup, so the EOS id is hardcoded here
EOS_TOKEN = 128258
def prepare_inputs(text: str, voice: str):
    # Wrap "<voice>: <text>" in the special start/end tokens expected by the model
    prompt = f"{voice}: {text}"
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], device=device)
    end = torch.tensor([END_TOKENS], device=device)
    input_ids = torch.cat([start, ids, end], dim=1)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
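# Illustrative call (the text token ids depend on the tokenizer):
#   prepare_inputs("Hallo Welt", "Jakob")
#   -> input_ids = [[128259, <ids of "Jakob: Hallo Welt">, 128009, 128260]]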
def decode_block(block: list[int]):
    # Build one PCM byte block from exactly 7 audio codes.
    # The model emits 7 codes per frame; SNAC expects them redistributed onto
    # its three codebooks (1 + 2 + 4 codes), each position offset by n*4096.
    l1, l2, l3 = [], [], []
    b = block
    l1.append(b[0])
    l2.append(b[1] - 4096)
    l3.append(b[2] - 2 * 4096)
    l3.append(b[3] - 3 * 4096)
    l2.append(b[4] - 4 * 4096)
    l3.append(b[5] - 5 * 4096)
    l3.append(b[6] - 6 * 4096)
    codes = [
        torch.tensor(l1, device=device).unsqueeze(0),
        torch.tensor(l2, device=device).unsqueeze(0),
        torch.tensor(l3, device=device).unsqueeze(0),
    ]
    with torch.inference_mode():
        audio = snac.decode(codes).squeeze().cpu().numpy()
    # float [-1, 1] -> 16-bit PCM
    return (audio * 32767).astype("int16").tobytes()
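# Illustrative call (code values are made up; code n must lie in [n*4096, (n+1)*4096)):
#   decode_block([10, 4100, 8200, 12300, 16400, 20500, 24600])
#   -> raw 16-bit PCM bytes for one frame at SNAC's 24 kHz rate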
# — WebSocket endpoint for TTS streaming —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        collected = []
        # Token-by-token generation with a custom sampling loop
        while True:
            with torch.inference_mode():
                out = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_kvs,
                    use_cache=True,
                )
            logits = out.logits[:, -1, :]
            past_kvs = out.past_key_values
            # Sampling
            probs = torch.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).item()
            # From the second step on, feed only the sampled token;
            # the KV cache carries the full context
            input_ids = torch.tensor([[nxt]], device=device)
            attention_mask = None
            # EOS -> done
            if nxt == EOS_TOKEN:
                break
            # RESET -> discard the codes collected so far
            if nxt == RESET_TOKEN:
                collected = []
                continue
            # Subtract the audio offset & collect
            collected.append(nxt - AUDIO_OFFSET)
            # Every 7 codes -> decode & stream
            if len(collected) == 7:
                pcm = decode_block(collected)
                collected = []
                await ws.send_bytes(pcm)
        # Close the stream cleanly
        await ws.close()
    except WebSocketDisconnect:
        # Client disconnected -> nothing to do
        pass
    except Exception as e:
        # On errors, close with code 1011
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
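# — Minimal client sketch — assumes the server listens on ws://localhost:7860
# and the third-party "websockets" package is installed; writes the PCM
# stream to a WAV file:
#
#   import asyncio, json, wave
#   from websockets import connect
#   from websockets.exceptions import ConnectionClosedOK
#
#   async def main():
#       async with connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with wave.open("out.wav", "wb") as wav:
#               wav.setnchannels(1)      # mono
#               wav.setsampwidth(2)      # 16-bit samples
#               wav.setframerate(24000)  # SNAC runs at 24 kHz
#               try:
#                   while True:
#                       wav.writeframes(await ws.recv())
#               except ConnectionClosedOK:
#                   pass  # server closed the socket after EOS
#
#   asyncio.run(main())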