Spaces:
Paused
Paused
File size: 4,324 Bytes
2c15189 0316ec3 4189fe1 9bf14d0 a09ea48 0316ec3 9bf14d0 2c15189 9bf14d0 2008a3f 9bf14d0 1ab029d 0316ec3 9bf14d0 0dfc310 9bf14d0 3281189 67c3132 9bf14d0 a8606ac 2c15189 a09ea48 4189fe1 9bf14d0 2c15189 9bf14d0 d4630a2 2c15189 9bf14d0 2c15189 4189fe1 9bf14d0 a09ea48 9bf14d0 2c15189 a09ea48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# — Hugging Face token & login —
# Authenticate against the Hub only when a token is present in the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# — Pick the compute device —
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# — Create the FastAPI application —
app = FastAPI()
# — Hello route, so that GET / no longer answers with a 404 —
@app.get("/")
async def read_root():
    """Return a fixed greeting payload for the root path."""
    greeting = {"message": "Hello, world!"}
    return greeting
# — Load the models at startup —
@app.on_event("startup")
async def load_models():
    """Load the SNAC codec, the tokenizer and the TTS model into module globals.

    Registered as a FastAPI startup handler. Populates the module-level names
    ``snac``, ``tokenizer`` and ``model`` that the request helpers rely on.
    """
    global tokenizer, model, snac

    on_gpu = device == "cuda"

    # SNAC codec, moved to the selected device
    codec = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac = codec.to(device)

    # TTS model and its tokenizer
    model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    load_kwargs = {
        # On CUDA everything is placed on GPU 0 in bfloat16; otherwise the
        # library defaults are used.
        "device_map": {"": 0} if on_gpu else None,
        "torch_dtype": torch.bfloat16 if on_gpu else None,
        "low_cpu_mem_usage": True,
    }
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    # Use the EOS id as the padding id
    model.config.pad_token_id = model.config.eos_token_id
# — Helper functions —
def prepare_inputs(text: str, voice: str):
    """Build the model input ids and attention mask for one request.

    The prompt is ``"{voice}: {text}"``. Its token ids are wrapped with one
    leading marker id (128259) and two trailing marker ids (128009, 128260).

    Args:
        text: The text to synthesize.
        voice: The speaker name placed in front of the text.

    Returns:
        A tuple ``(ids, mask)`` of tensors on ``device``; ``mask`` is all
        ones and has the same shape as ``ids``.
    """
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt")
    prompt_ids = encoded.input_ids.to(device)

    # Start / end markers surrounding the tokenized prompt
    prefix = torch.tensor([[128259]], dtype=torch.int64, device=device)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)

    ids = torch.cat((prefix, prompt_ids, suffix), dim=1)
    return ids, torch.ones_like(ids)
def decode_block(block_tokens: list[int]):
    """Decode one 7-token frame into raw PCM16 bytes.

    The seven entries are de-interleaved into three SNAC code layers (one
    code for layer 1, two for layer 2, four for layer 3). Slot ``i`` of the
    frame carries an offset of ``i * 4096`` that is removed here.

    Args:
        block_tokens: At least seven audio codes (already shifted so that the
            first audio code is 0); only the first seven are used.

    Returns:
        The decoded audio as int16 sample bytes (native byte order).

    Raises:
        IndexError: If fewer than seven tokens are supplied.
    """
    b = block_tokens
    layer1 = [b[0]]
    layer2 = [b[1] - 4096, b[4] - 4 * 4096]
    layer3 = [
        b[2] - 2 * 4096,
        b[3] - 3 * 4096,
        b[5] - 5 * 4096,
        b[6] - 6 * 4096,
    ]
    codes = [
        torch.tensor(layer, device=device).unsqueeze(0)
        for layer in (layer1, layer2, layer3)
    ]

    # Decode without autograd: with gradients enabled the decoder output
    # requires grad and ``.numpy()`` refuses to convert such a tensor.
    with torch.inference_mode():
        # Per the original author's note, decode yields float audio at 24 kHz.
        audio = snac.decode(codes).squeeze()
        # Clip to [-1, 1] before scaling; otherwise out-of-range samples wrap
        # around when cast to int16 and produce loud clicks.
        samples = audio.clamp(-1.0, 1.0).cpu().numpy()

    # Convert to PCM16
    return (samples * 32767).astype("int16").tobytes()
# — WebSocket endpoint for TTS streaming —
def _sample_next_token(step_ids, step_mask, past_kvs):
    """Run one forward pass and sample the next token id from the logits.

    Returns a tuple ``(token_id, past_key_values)``.
    """
    # inference_mode is scoped to this synchronous call so that the
    # thread-local grad state is never held across an ``await``.
    with torch.inference_mode():
        out = model(
            input_ids=step_ids,
            attention_mask=step_mask,
            past_key_values=past_kvs,
            use_cache=True,
        )
        probs = torch.softmax(out.logits[:, -1, :], dim=-1)
        token_id = torch.multinomial(probs, num_samples=1).item()
    return token_id, out.past_key_values


@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech for one request over a WebSocket.

    Protocol: the client sends a single JSON text message with the keys
    ``"text"`` (default ``""``) and ``"voice"`` (default ``"Jakob"``). The
    server answers with binary messages, each holding the PCM16 bytes of one
    decoded 7-token frame, and then closes the socket. On an unexpected
    error the socket is closed with code 1011.
    """
    await ws.accept()
    try:
        # The request arrives first, as JSON
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # Build the prompt; it is fed to the model only on the first step
        step_ids, step_mask = prepare_inputs(text, voice)
        past_kvs = None
        collected = []

        # ``eos_token_id`` may be a single id or a list of ids.
        eos = model.config.eos_token_id
        stop_ids = set(eos) if isinstance(eos, (list, tuple)) else {eos}
        # 128258 is the end-of-speech id that the Orpheus reference inference
        # code passes as EOS — assumed to apply to this checkpoint, TODO confirm.
        stop_ids.add(128258)

        # Token-by-token loop
        # NOTE(review): each forward pass runs synchronously and therefore
        # blocks the event loop while it computes.
        while True:
            nxt, past_kvs = _sample_next_token(step_ids, step_mask, past_kvs)

            # Stop on EOS / end of speech
            if nxt in stop_ids:
                break

            # Feed the sampled token back in: with a KV cache the next step
            # takes only the newest token. Without this the model would be
            # called with no input at all from the second step on.
            step_ids = torch.tensor([[nxt]], dtype=torch.int64, device=device)
            step_mask = None

            # Reset on a new start marker
            if nxt == 128257:
                collected = []
                continue

            # Shift into audio-code space and collect
            collected.append(nxt - 128266)

            # As soon as a frame of 7 is complete, decode and send it
            if len(collected) == 7:
                # Decode without autograd (synchronous call, no await inside).
                with torch.inference_mode():
                    pcm = decode_block(collected)
                collected = []
                await ws.send_bytes(pcm)

        # Close cleanly once generation has finished
        await ws.close()
    except WebSocketDisconnect:
        # The client disconnected; nothing left to do
        pass
    except Exception as e:
        # On errors, report and close with 1011
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except RuntimeError:
            # The socket was already closed; a second close is not possible
            pass
|