import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download
# — ENV & HF auth —
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# — Debug: CPU mode for development; switch back to "cuda" later —
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
# — Load models —
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
model_name = "canopylabs/3b-de-pretrain-research_release"
# optional: explicit snapshot_download (excludes large files)
'''
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
        "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
    ]
)
'''
print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
).to(device)
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)
# — Helper functions —
def process_prompt(text: str, voice: str):
    """Wrap the text in the Orpheus prompt format and add framing special tokens."""
    prompt = f"{voice}: {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    start = torch.tensor([[128259]], dtype=torch.int64)        # start-of-turn marker
    end = torch.tensor([[128009, 128260]], dtype=torch.int64)  # end-of-text / end-of-turn
    ids = torch.cat([start, input_ids, end], dim=1).to(device)
    mask = torch.ones_like(ids)
    return ids, mask
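# Example (hypothetical values): process_prompt("Hallo Welt", "jana") tokenizes
# "jana: Hallo Welt" and returns ids of shape (1, n_text_tokens + 3), laid out as
# [128259, <text tokens>, 128009, 128260], plus an all-ones attention mask.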
def parse_output(generated_ids: torch.LongTensor):
    """Extract the audio code list after the last 128257 start-of-audio token."""
    token_to_find = 128257
    token_to_remove = 128258
    # 1) Find the last start-of-audio token and crop everything before it
    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cut = idxs[-1].item() + 1
        cropped = generated_ids[:, cut:]
    else:
        cropped = generated_ids
    # 2) Strip EOS/padding markers and trim to whole 7-token frames
    row = cropped[0]
    row = row[row != token_to_remove]
    row = row[: (row.size(0) // 7) * 7]
    # 3) Subtract the audio-token ID offset (128266 in the Orpheus release) so
    #    the values index the SNAC codebooks, then return a flat Python list
    return (row - 128266).tolist()
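# SNAC decodes three codebook layers at coarse/medium/fine time resolution.
# Orpheus emits 7 codes per frame: 1 for layer 1, 2 for layer 2, 4 for layer 3,
# each shifted by position*4096 so positions stay distinguishable in the token
# space; those offsets are undone below before decoding.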
def redistribute_codes(code_list: list[int], snac_model: SNAC):
    """Distribute the codes across SNAC's three layers and decode to audio."""
    layer1, layer2, layer3 = [], [], []
    for i in range(len(code_list) // 7):  # one 7-code frame per iteration
        base = code_list[7*i : 7*i+7]
        layer1.append(base[0])
        layer2.append(base[1] - 4096)
        layer3.append(base[2] - 2*4096)
        layer3.append(base[3] - 3*4096)
        layer2.append(base[4] - 4*4096)
        layer3.append(base[5] - 5*4096)
        layer3.append(base[6] - 6*4096)
    dev = next(snac_model.parameters()).device
    codes = [
        torch.tensor(layer1, device=dev).unsqueeze(0),
        torch.tensor(layer2, device=dev).unsqueeze(0),
        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
    audio = snac_model.decode(codes)
    return audio.detach().squeeze().cpu().numpy()  # float32 @ 24 kHz
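# Offline smoke test (a sketch, commented out like the snapshot_download block
# above; assumes the optional `soundfile` package, which this app does not
# otherwise require):
'''
import soundfile as sf
ids, mask = process_prompt("Hallo Welt", "jana")
gen = model.generate(input_ids=ids, attention_mask=mask, max_new_tokens=200,
                     do_sample=True, temperature=0.7, top_p=0.95,
                     repetition_penalty=1.1, eos_token_id=128258)
sf.write("test.wav", redistribute_codes(parse_output(gen), snac), 24000)
'''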
# — FastAPI + WebSocket endpoint —
app = FastAPI()
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        while True:
            msg = await ws.receive_text()
            data = json.loads(msg)
            text = data.get("text", "")
            voice = data.get("voice", "jana")
            # 1) Prompt → tokens
            ids, mask = process_prompt(text, voice)
            # 2) Token generation (start small while testing!)
            gen_ids = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_new_tokens=200,  # keep small for debugging
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=128258,
            )
            # 3) Tokens → code list → audio
            code_list = parse_output(gen_ids)
            audio_np = redistribute_codes(code_list, snac)
            # 4) PCM16 bytes, streamed in 0.1 s chunks
            pcm16 = (audio_np * 32767).astype("int16").tobytes()
            chunk = 2400 * 2  # 2400 samples @ 24 kHz → 0.1 s, 2 bytes per sample
            for i in range(0, len(pcm16), chunk):
                await ws.send_bytes(pcm16[i : i + chunk])
                await asyncio.sleep(0.1)
    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
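# Minimal test client (a sketch; assumes the third-party `websockets` package
# and the server running locally on port 7860. The 5 s timeout is illustrative,
# since the stream has no end-of-utterance framing yet):
'''
import asyncio, json, wave
import websockets

async def main():
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        await ws.send(json.dumps({"text": "Hallo Welt", "voice": "jana"}))
        with wave.open("out.wav", "wb") as wav:
            wav.setnchannels(1)      # mono
            wav.setsampwidth(2)      # PCM16
            wav.setframerate(24000)  # matches the 24 kHz SNAC output
            try:
                while True:
                    wav.writeframes(await asyncio.wait_for(ws.recv(), timeout=5))
            except asyncio.TimeoutError:
                pass  # no chunk for 5 s: treat the utterance as finished

asyncio.run(main())
'''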