Spaces:
Paused
Paused
File size: 4,418 Bytes
2c15189 0316ec3 d9ea17d 4189fe1 9bf14d0 a09ea48 d9ea17d 0316ec3 d9ea17d 2c15189 9bf14d0 2008a3f d9ea17d 1ab029d 0316ec3 9bf14d0 0dfc310 9bf14d0 d9ea17d 9bf14d0 d9ea17d 9bf14d0 d9ea17d 9bf14d0 d9ea17d 9bf14d0 d9ea17d 9bf14d0 3281189 d9ea17d 9bf14d0 d9ea17d 9bf14d0 d9ea17d 9bf14d0 d9ea17d a8606ac 2c15189 a09ea48 4189fe1 d9ea17d 9bf14d0 d9ea17d c70d8eb fd06e70 4189fe1 fd06e70 a09ea48 2c15189 a09ea48 d9ea17d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import json
import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC
# — Hugging Face token & login (only when the token is set) —
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# — Pick the device: CUDA when torch reports it as available, else CPU —
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The ASGI application; routes and the startup hook are registered on it.
app = FastAPI()
@app.get("/")
async def read_root():
    """Answer GET / with a fixed greeting payload."""
    greeting = {"message": "Hello, world!"}
    return greeting
# — Global model handles —
# All three start out as None and are assigned by the startup hook.
model = tokenizer = snac_model = None
# — Startup: load SNAC & Orpheus —
@app.on_event("startup")
async def load_models():
    """Load the SNAC decoder, the tokenizer and the causal LM into the globals.

    Registered as a FastAPI startup hook. Assigns the module-level
    ``snac_model``, ``tokenizer`` and ``model`` and sets the model's
    ``pad_token_id`` to its ``eos_token_id``.
    """
    global model, tokenizer, snac_model

    # 1) SNAC decoder
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # 2) Orpheus TTS language model
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)

    use_cuda = device == "cuda"
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map="auto" if use_cuda else None,
        torch_dtype=torch.bfloat16 if use_cuda else None,
        low_cpu_mem_usage=True,
    )
    # With device_map="auto" the weights are already placed by accelerate;
    # calling .to() on such a model is unsupported and fails when modules
    # were offloaded. Only move the model when no device map was used.
    if not use_cuda:
        model = model.to(device)

    model.config.pad_token_id = model.config.eos_token_id
# — Marker ids and offset taken from the template —
# START_TOKEN is prepended and END_TOKENS are appended to the tokenized
# prompt; AUDIO_OFFSET is subtracted from every generated token that is kept.
START_TOKEN = 128_259
END_TOKENS = [128_009, 128_260]
AUDIO_OFFSET = 128_266
def process_single_prompt(prompt: str, voice: str) -> list[int]:
    """Generate tokens for *prompt* and return them shifted by AUDIO_OFFSET.

    The prompt is prefixed with ``"<voice>: "`` unless *voice* is empty or
    equals ``"in_prompt"``. The tokenized text is wrapped in START_TOKEN and
    END_TOKENS, passed to ``model.generate``, and the output is cropped after
    the last occurrence of token 128257, stripped of token 128258 and
    truncated to a multiple of 7 entries.

    Returns:
        The remaining token ids, each reduced by AUDIO_OFFSET.
    """
    # Build the prompt text
    use_voice_prefix = bool(voice) and voice != "in_prompt"
    text = f"{voice}: {prompt}" if use_voice_prefix else prompt

    # Tokenize and wrap in the start/end markers
    prompt_ids = tokenizer(text, return_tensors="pt").input_ids
    prefix = torch.tensor([[START_TOKEN]], dtype=torch.int64)
    suffix = torch.tensor([END_TOKENS], dtype=torch.int64)
    input_ids = torch.cat([prefix, prompt_ids, suffix], dim=1).to(device)

    # Generate
    generated = model.generate(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=4000,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.1,
        eos_token_id=128258,
        use_cache=True,
    )

    # Keep only what follows the last 128257 marker, if there is one
    marker_id = 128257
    stop_id = 128258  # same id that is passed as eos_token_id above
    marker_cols = (generated == marker_id).nonzero(as_tuple=True)[1]
    if marker_cols.numel() > 0:
        tail = generated[:, marker_cols[-1] + 1:]
    else:
        tail = generated

    # Drop every occurrence of the stop id from the first (only used) row
    first_row = tail[0]
    kept = first_row[first_row != stop_id]

    # Truncate to a whole number of 7-token groups
    usable = (kept.size(0) // 7) * 7

    # Subtract the offset
    return [token - AUDIO_OFFSET for token in kept[:usable].tolist()]
def redistribute_codes(code_list: list[int]) -> np.ndarray:
    """Split 7-code frames into three layers and decode them with SNAC.

    Within each frame, position ``k`` has ``k * 4096`` subtracted from it.
    Position 0 goes to layer 1, positions 1 and 4 to layer 2, and positions
    2, 3, 5 and 6 to layer 3. Trailing codes that do not fill a complete
    frame are ignored.

    Args:
        code_list: Flat list of codes, consumed in groups of 7.

    Returns:
        The output of ``snac_model.decode`` squeezed and converted to a
        NumPy array on the CPU.

    Raises:
        ValueError: If *code_list* holds fewer than 7 codes. Without this
            check the layers would be empty float tensors and the decoder
            would fail with a much less helpful error.
    """
    n_frames = len(code_list) // 7
    if n_frames == 0:
        raise ValueError(
            "code_list must contain at least one complete 7-code frame"
        )

    stride = 4096  # per-position offset step inside a frame
    layer1, layer2, layer3 = [], [], []
    for start in range(0, n_frames * 7, 7):
        c0, c1, c2, c3, c4, c5, c6 = code_list[start:start + 7]
        layer1.append(c0)
        layer2.append(c1 - stride)
        layer3.append(c2 - 2 * stride)
        layer3.append(c3 - 3 * stride)
        layer2.append(c4 - 4 * stride)
        layer3.append(c5 - 5 * stride)
        layer3.append(c6 - 6 * stride)

    # One (1, N) tensor per layer, on the same device as the decoder.
    codes = [
        torch.tensor(layer, device=device).unsqueeze(0)
        for layer in (layer1, layer2, layer3)
    ]
    audio = snac_model.decode(codes).squeeze().cpu().numpy()
    return audio  # float32 @ 24 kHz per the original note
# — WebSocket endpoint for TTS —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Serve one TTS request per WebSocket connection.

    Expects a single JSON text message with optional ``"text"`` and
    ``"voice"`` keys, replies with one binary message of 16-bit PCM samples
    and then closes the socket. Any error other than a client disconnect is
    printed and the socket is closed with code 1011.
    """
    await ws.accept()
    try:
        # 1) Receive text + voice
        msg = await ws.receive_text()
        req = json.loads(msg)
        text = req.get("text", "")
        voice = req.get("voice", "")

        # 2) Prompt -> code list -> audio.
        # NOTE: these calls are synchronous, so the event loop is blocked
        # for the duration of generation and decoding.
        with torch.no_grad():
            codes = process_single_prompt(text, voice)
            audio_np = redistribute_codes(codes)

        # 3) Convert to PCM16 and send. Clip first: a sample outside
        # [-1, 1] would overflow the int16 range when scaled and cast.
        clipped = np.clip(audio_np, -1.0, 1.0)
        pcm16 = (clipped * 32767).astype(np.int16).tobytes()
        await ws.send_bytes(pcm16)

        # 4) Close cleanly
        await ws.close()
    except WebSocketDisconnect:
        # Client went away; there is nobody left to answer.
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
    )
|