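"""FastAPI TTS server for the Kartoffel_Orpheus-3B German voice model.

Generates speech tokens with the Orpheus LLM, decodes them to 24 kHz audio
with SNAC, and streams the result as PCM16 chunks over a WebSocket at
/ws/tts.
"""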
import os
import json
import asyncio
import numpy as np
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download
# — ENV & HF auth —
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
login(token=HF_TOKEN)
# — Device —
device = "cuda" if torch.cuda.is_available() else "cpu"
# — Load models —
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
print("Downloading model weights (config + safetensors)...")
snapshot_download(
repo_id=model_name,
allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
ignore_patterns=[
"optimizer.pt", "pytorch_model.bin", "training_args.bin",
"scheduler.pt", "tokenizer.json", "tokenizer_config.json",
"special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
]
)
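# Tokenizer files are excluded from the snapshot above; AutoTokenizer
# fetches them on demand when the tokenizer is loaded below.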
print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16
).to(device)
model.config.pad_token_id = model.config.eos_token_id  # reuse EOS as the pad token
tokenizer = AutoTokenizer.from_pretrained(model_name)
# — Helper functions —
def process_prompt(text: str, voice: str):
    """Prepare input_ids and attention_mask for the model."""
    prompt = f"{voice}: {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Orpheus special tokens: 128259 = start of human turn,
    # 128009 = end of text, 128260 = end of human turn
    start = torch.tensor([[128259]], dtype=torch.int64)
    end = torch.tensor([[128009, 128260]], dtype=torch.int64)
    ids = torch.cat([start, input_ids, end], dim=1).to(device)
    mask = torch.ones_like(ids).to(device)
    return ids, mask
def parse_output(generated_ids: torch.LongTensor):
    """Extract the raw code list after the last 128257 start-of-audio token."""
    token_to_find = 128257    # marks the start of the audio code stream
    token_to_remove = 128258  # end-of-speech token to strip from the codes
    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cut = idxs[-1].item() + 1
        cropped = generated_ids[:, cut:]
    else:
        cropped = generated_ids
    # Strip the EOS token
    row = cropped[0]
    return row[row != token_to_remove].tolist()
def redistribute_codes(code_list: list[int], snac_model: SNAC):
    """
    Distribute the tokens across the three SNAC layers in complete blocks of
    seven and decode to audio. Incomplete remainders (<7 tokens) are discarded.
    """
n_blocks = len(code_list) // 7
layer1, layer2, layer3 = [], [], []
for i in range(n_blocks):
base = code_list[7*i : 7*i + 7]
layer1.append(base[0])
layer2.append(base[1] - 4096)
layer3.append(base[2] - 2*4096)
layer3.append(base[3] - 3*4096)
layer2.append(base[4] - 4*4096)
layer3.append(base[5] - 5*4096)
layer3.append(base[6] - 6*4096)
    if not layer1:
        # no complete block -> return empty audio
        return np.zeros(0, dtype=np.float32)
dev = next(snac_model.parameters()).device
codes = [
torch.tensor(layer1, device=dev).unsqueeze(0),
torch.tensor(layer2, device=dev).unsqueeze(0),
torch.tensor(layer3, device=dev).unsqueeze(0),
]
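    # snac_model.decode returns a (1, 1, n_samples) float tensor in [-1, 1]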
audio = snac_model.decode(codes)
return audio.detach().squeeze().cpu().numpy()
# — FastAPI Setup —
app = FastAPI()
@app.get("/")
def greet_json():
return {"Hello": "World!"}
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
await ws.accept()
try:
while True:
            # Expects JSON: {"text": "...", "voice": "Jakob"}
data = json.loads(await ws.receive_text())
text = data.get("text", "")
voice = data.get("voice", "Jakob")
            # 1) Prepare tokens
ids, mask = process_prompt(text, voice)
            # 2) Generation
gen_ids = model.generate(
input_ids=ids,
attention_mask=mask,
                max_new_tokens=2000,  # adjust as needed
do_sample=True,
temperature=0.7,
top_p=0.95,
repetition_penalty=1.1,
eos_token_id=model.config.eos_token_id,
)
            # 3) Tokens → code list → audio
codes = parse_output(gen_ids)
audio_np = redistribute_codes(codes, snac)
            # 4) Stream PCM16 in 0.1 s chunks
            # Clip to [-1, 1] so the int16 cast cannot overflow
            pcm16 = (np.clip(audio_np, -1.0, 1.0) * 32767).astype("int16").tobytes()
            chunk = 2400 * 2  # 2400 samples @ 24 kHz = 0.1 s; 2 bytes per sample
for i in range(0, len(pcm16), chunk):
await ws.send_bytes(pcm16[i : i+chunk])
await asyncio.sleep(0.1)
            # end of the while loop
except WebSocketDisconnect:
print("Client disconnected")
except Exception as e:
print("Error in /ws/tts:", e)
await ws.close(code=1011)
if __name__ == "__main__":
import uvicorn
uvicorn.run("app:app", host="0.0.0.0", port=7860)
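# Minimal example client (a sketch; assumes the third-party `websockets`
# package and a server running locally on port 7860):
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               try:
#                   while True:
#                       f.write(await asyncio.wait_for(ws.recv(), timeout=5.0))
#               except asyncio.TimeoutError:
#                   pass  # no chunk for 5 s -> assume the utterance finished
#
#   asyncio.run(demo())
#
# The result is raw mono PCM16 at 24 kHz, playable with e.g.:
#   ffplay -f s16le -ar 24000 -ac 1 out.pcm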