import os
import json

import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC
# --- HF token & login (if set) ---
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# --- Select device ---
device = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI()


@app.get("/")
async def read_root():
    return {"message": "Hello, world!"}
# --- Global models ---
model = None
tokenizer = None
snac_model = None
@app.on_event("startup")
async def load_models():
    global model, tokenizer, snac_model
    # 1) Load SNAC
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # 2) Orpheus TTS (public "natural" variant)
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map="auto" if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    ).to(device)
    model.config.pad_token_id = model.config.eos_token_id
# --- Markers and offsets ---
START_TOKEN = 128259
END_TOKENS = [128009, 128260]
AUDIO_OFFSET = 128266
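# The model emits audio tokens in frames of 7 codes. After subtracting AUDIO_OFFSET,
# the n-th code of a frame still carries an extra offset of n * 4096, which
# redistribute_codes() below removes while mapping the frame onto the 3 SNAC layers
# (1 coarse + 2 medium + 4 fine codes).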
def process_single_prompt(prompt: str, voice: str) -> list[int]:
    # Assemble the prompt
    text = f"{voice}: {prompt}" if voice and voice != "in_prompt" else prompt
    # Tokenize + add start/end markers
    ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
    input_ids = torch.cat([start, ids, end], dim=1)
    attention_mask = torch.ones_like(input_ids)
    # Generate
    gen = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=4000,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.1,
        eos_token_id=128258,
        use_cache=True,
    )
    # Crop everything up to (and including) the last audio-start marker (128257)
    token_to_find = 128257
    token_to_remove = 128258
    idxs = (gen == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = gen[:, idxs[-1] + 1 :]
    else:
        cropped = gen
    # Drop the EOS token & trim the length to a multiple of 7
    row = cropped[0][cropped[0] != token_to_remove]
    new_len = (row.size(0) // 7) * 7
    trimmed = row[:new_len].tolist()
    # Subtract the audio token offset
    return [t - AUDIO_OFFSET for t in trimmed]
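# prepare_inputs() is called by the WebSocket handler below but is not defined in this
# file. Minimal sketch, assuming it wraps the prompt with the same start/end markers as
# process_single_prompt() above and returns the tensors needed for generation.
def prepare_inputs(prompt: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    text = f"{voice}: {prompt}" if voice and voice != "in_prompt" else prompt
    ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    end = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
    input_ids = torch.cat([start, ids, end], dim=1)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask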
def redistribute_codes(code_list: list[int]) -> np.ndarray:
    # Distribute each block of 7 codes across the 3 SNAC layers
    layer1, layer2, layer3 = [], [], []
    for i in range(len(code_list) // 7):
        b = code_list[7 * i : 7 * i + 7]
        layer1.append(b[0])
        layer2.append(b[1] - 4096)
        layer3.append(b[2] - 2 * 4096)
        layer3.append(b[3] - 3 * 4096)
        layer2.append(b[4] - 4 * 4096)
        layer3.append(b[5] - 5 * 4096)
        layer3.append(b[6] - 6 * 4096)
    codes = [
        torch.tensor(layer1, device=device).unsqueeze(0),
        torch.tensor(layer2, device=device).unsqueeze(0),
        torch.tensor(layer3, device=device).unsqueeze(0),
    ]
    audio = snac_model.decode(codes).squeeze().cpu().numpy()
    return audio  # float32 @ 24 kHz
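# decode_block() is used by the streaming handler below but is not defined in this file.
# Minimal sketch, assuming it decodes a single 7-code frame via redistribute_codes() and
# returns the result as bytes for ws.send_bytes(); the 16-bit PCM packing here is an
# assumption, not confirmed by the original code.
def decode_block(block: list[int]) -> bytes:
    audio = redistribute_codes(block)  # float32 samples in [-1, 1] at 24 kHz
    pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    return pcm16.tobytes()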
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "")
        # 1) Prepare the prompt
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        buffer = []
        # 2) Generate token by token (or in small chunks)
        while True:
            # Only max_new_tokens=50 per call
            out = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_kvs,
                use_cache=True,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                max_new_tokens=50,
                eos_token_id=128258,
                return_dict_in_generate=True,
                output_past_key_values=True,
                return_legacy_cache=True,  # in case the old past_key_values format is still needed
            )
            # Extract the newly generated tokens (without the already generated ones)
            new_ids = out.sequences[0, input_ids.shape[-1]:].tolist()
            past_kvs = out.past_key_values
            for tok in new_ids:
                if tok == model.config.eos_token_id:
                    # End of the stream
                    break
                if tok == 128257:  # audio-start marker: reset the buffer
                    buffer = []
                    continue
                buffer.append(tok - AUDIO_OFFSET)
                # As soon as 7 audio codes have been collected, decode & send
                if len(buffer) == 7:
                    pcm = decode_block(buffer)
                    buffer = []
                    await ws.send_bytes(pcm)
            # If EOS was in this chunk, stop
            if model.config.eos_token_id in new_ids:
                break
            # Continue with the next 50 tokens: feed the full sequence generated so far
            # back in together with the cached key/values so generation resumes in place
            input_ids = out.sequences
            attention_mask = torch.ones_like(input_ids)
        # 3) Finally, close the WebSocket cleanly
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
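# --- Example client (sketch) ---
# A minimal, hypothetical client for the /ws/tts endpoint, assuming the server runs
# locally on port 7860 and the `websockets` package is installed. It is not part of the
# server and is kept commented out here purely as a usage illustration.
#
#   import asyncio, json
#   import websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Guten Tag!", "voice": ""}))
#           chunks = []
#           try:
#               while True:
#                   chunks.append(await ws.recv())  # raw PCM bytes, 24 kHz
#           except websockets.exceptions.ConnectionClosed:
#               pass
#           # join `chunks` and write them to a WAV file or an audio sink
#
#   asyncio.run(demo())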