import os
import json
import asyncio
import torch

# Workaround for the PyTorch 2.2.x flash SDP assertion
torch.backends.cuda.enable_flash_sdp(False)

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- HF token & login ---
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# --- Select device ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Instantiate FastAPI ---
app = FastAPI()
# --- Hello route so that GET / no longer returns a 404 ---
@app.get("/")
async def read_root():
    return {"message": "Orpheus TTS WebSocket server is running"}
# --- Load models at startup ---
@app.on_event("startup")
async def load_models():
    global tokenizer, model, snac

    # SNAC for audio decoding
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # Orpheus TTS base model
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
        return_legacy_cache=True,  # compatibility with past_key_values as a tuple
    ).to(device)
    model.config.pad_token_id = model.config.eos_token_id
    # --- Prepare logit masking ---
    # pure audio tokens range from 128266 to 128266 + 4096 - 1
    AUDIO_OFFSET = 128266
    AUDIO_COUNT = 4096
    valid_audio = torch.arange(AUDIO_OFFSET, AUDIO_OFFSET + AUDIO_COUNT, device=device)
    ctrl_tokens = torch.tensor([128257, model.config.eos_token_id], device=device)

    global ALLOWED_IDS
    ALLOWED_IDS = torch.cat([valid_audio, ctrl_tokens])
def sample_from_logits(logits: torch.Tensor) -> int:
    """
    Masks all IDs except ALLOWED_IDS and then samples one token.
    """
    # logits: [1, vocab_size]
    mask = torch.full_like(logits, float("-inf"))
    mask[0, ALLOWED_IDS] = 0.0
    probs = torch.softmax(logits + mask, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()
def prepare_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # start/end markers
    start = torch.tensor([[128259]], dtype=torch.int64, device=device)
    end = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    input_ids = torch.cat([start, ids, end], dim=1)
    attention_mask = torch.ones_like(input_ids, device=device)
    return input_ids, attention_mask
def decode_block(block: list[int]) -> bytes:
    """
    Decode 7 sampled audio codes into one block of PCM16 bytes.
    Here we expect block[i] = raw_token - 128266.
    """
    layer1, layer2, layer3 = [], [], []
    b = block
    # redistribute the 7 codes across the three SNAC code layers
    layer1.append(b[0])
    layer2.append(b[1] - 4096)
    layer3.append(b[2] - 2 * 4096)
    layer3.append(b[3] - 3 * 4096)
    layer2.append(b[4] - 4 * 4096)
    layer3.append(b[5] - 5 * 4096)
    layer3.append(b[6] - 6 * 4096)

    dev = next(snac.parameters()).device
    codes = [
        torch.tensor(layer1, device=dev).unsqueeze(0),
        torch.tensor(layer2, device=dev).unsqueeze(0),
        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    # convert to PCM16
    pcm16 = (audio * 32767).astype("int16").tobytes()
    return pcm16
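
# Illustrative helper (not part of the original flow): a minimal sketch for dumping
# the streamed PCM16 blocks to a WAV file for offline inspection. It assumes mono
# audio at 24 kHz, matching the snac_24khz checkpoint; the function name is hypothetical.
def write_wav_debug(pcm_blocks: list[bytes], path: str, sample_rate: int = 24000) -> None:
    import wave

    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)   # mono
        wf.setsampwidth(2)   # 16-bit PCM
        wf.setframerate(sample_rate)
        for block in pcm_blocks:
            wf.writeframes(block)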
# --- WebSocket endpoint for TTS streaming ---
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # prepare inputs
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        buffer = []  # collects the 7 audio codes of one block

        # token-by-token loop
        while True:
            out = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_kvs,
                use_cache=True,
                return_dict=True,
            )
            past_kvs = out.past_key_values
            next_tok = sample_from_logits(out.logits[:, -1, :])

            # done?
            if next_tok == model.config.eos_token_id:
                break

            if next_tok == 128257:
                # a new audio block starts: drop any partial codes
                buffer.clear()
            else:
                # collect the audio code (subtract the offset)
                buffer.append(next_tok - 128266)
                # as soon as we have 7 codes, decode and send them
                if len(buffer) == 7:
                    pcm = decode_block(buffer)
                    buffer.clear()
                    await ws.send_bytes(pcm)

            # next step: feed exactly this token back in and extend the mask by one
            input_ids = torch.tensor([[next_tok]], device=device)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((1, 1), dtype=attention_mask.dtype, device=device)],
                dim=1,
            )

        # close cleanly
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
# --- CLI for local testing ---
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
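
# A minimal client sketch (illustrative, not part of the server): it assumes the
# server is reachable at ws://localhost:7860/ws/tts and that the third-party
# "websockets" package is installed. It sends one request and collects the raw
# PCM16 blocks streamed back until the server closes the connection.
#
#     import asyncio, json
#     import websockets
#
#     async def main():
#         async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#             await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#             pcm_blocks = []
#             try:
#                 while True:
#                     pcm_blocks.append(await ws.recv())
#             except websockets.ConnectionClosed:
#                 pass
#             # pcm_blocks now holds raw 16-bit mono PCM at 24 kHz
#
#     asyncio.run(main())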