import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
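# Streaming TTS server: a WebSocket client sends JSON ({"text": ..., "voice": ...}),
# the Orpheus model generates audio tokens, SNAC decodes them to 24 kHz audio,
# and the result is streamed back to the client as raw PCM16 chunks.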
# — Environment & HF auth —
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# — Device & model loading —
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

model_name = "canopylabs/3b-de-ft-research_release"
print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
).to(device)
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)
# — Helper functions —
def process_prompt(text: str, voice: str):
    """Build '<voice>: <text>' and wrap it in Orpheus' special control tokens."""
    prompt = f"{voice}: {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    start = torch.tensor([[128259]], dtype=torch.int64)        # start-of-human marker
    end = torch.tensor([[128009, 128260]], dtype=torch.int64)  # end-of-text, end-of-human markers
    ids = torch.cat([start, input_ids, end], dim=1).to(device)
    mask = torch.ones_like(ids).to(device)
    return ids, mask
def parse_output(generated_ids: torch.LongTensor):
    """Extract the audio tokens from a generation and decode them with SNAC."""
    token_to_find = 128257    # start-of-audio marker
    token_to_remove = 128258  # end-of-audio / padding marker
    # keep only the tokens after the last start-of-audio marker
    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        last = idxs[-1].item()
        cropped = generated_ids[:, last + 1:]
    else:
        cropped = generated_ids
    # remove end-of-audio / padding token markers
    rows = []
    for row in cropped:
        rows.append(row[row != token_to_remove])
    flat = rows[0].tolist()
    # shift each token back into its codebook range and regroup the flat
    # 7-token frames into SNAC's three codebook layers
    layer1, layer2, layer3 = [], [], []
    for i in range(len(flat) // 7):
        base = flat[7 * i : 7 * i + 7]
        layer1.append(base[0])
        layer2.append(base[1] - 4096)
        layer3.extend([base[2] - (2 * 4096), base[3] - (3 * 4096)])
        layer2.append(base[4] - (4 * 4096))
        layer3.extend([base[5] - (5 * 4096), base[6] - (6 * 4096)])
    codes = [
        torch.tensor(layer1, device=device).unsqueeze(0),
        torch.tensor(layer2, device=device).unsqueeze(0),
        torch.tensor(layer3, device=device).unsqueeze(0),
    ]
    audio = snac.decode(codes).detach().squeeze().cpu().numpy()
    return audio  # float32 numpy array at 24000 Hz
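# Per 7-token frame, the mapping implemented above is:
#   layer1 <- t0                              (1 coarse code)
#   layer2 <- t1 - 4096, t4 - 4*4096          (2 mid codes)
#   layer3 <- t2 - 2*4096, t3 - 3*4096,
#             t5 - 5*4096, t6 - 6*4096        (4 fine codes)
# i.e. the model emits the three SNAC codebooks interleaved, each shifted into
# its own 4096-wide slice of the vocabulary.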
# — FastAPI + WebSocket endpoint —
app = FastAPI()

@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        while True:
            msg = await ws.receive_text()
            data = json.loads(msg)
            text = data.get("text", "")
            voice = data.get("voice", "jana")
            # Generate tokens
            ids, mask = process_prompt(text, voice)
            with torch.no_grad():
                gen_ids = model.generate(
                    input_ids=ids,
                    attention_mask=mask,
                    max_new_tokens=1200,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    repetition_penalty=1.1,
                    eos_token_id=128258,
                )
            # Convert to waveform
            audio = parse_output(gen_ids)
            # PCM16 conversion & chunking
            pcm16 = (audio * 32767).astype("int16").tobytes()
            # 0.1 s @ 24 kHz = 2400 samples = 4800 bytes
            chunk_size = 2400 * 2
            for i in range(0, len(pcm16), chunk_size):
                await ws.send_bytes(pcm16[i : i + chunk_size])
                await asyncio.sleep(0.1)  # pace chunks to roughly real time
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
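# Minimal client sketch for this endpoint (assumptions: the server runs locally
# on port 7860, the third-party `websockets` package is installed, and "out.pcm"
# is an arbitrary output name). It sends one request and collects the PCM16
# stream until the server has been quiet for a few seconds:
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "jana"}))
#           with open("out.pcm", "wb") as f:  # raw mono PCM16 at 24 kHz
#               try:
#                   while True:
#                       f.write(await asyncio.wait_for(ws.recv(), timeout=5.0))
#               except asyncio.TimeoutError:
#                   pass  # no chunk for 5 s: assume the utterance is finished
#
#   asyncio.run(main())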