Spaces:
Paused
Paused
File size: 5,667 Bytes
2c15189 a3af518 0316ec3 a4cfefc 4189fe1 9bf14d0 d9ea17d a3af518 0316ec3 a3af518 d9ea17d 2c15189 a4cfefc 2008a3f a4cfefc 1ab029d 0316ec3 a4cfefc 9bf14d0 0dfc310 a4cfefc 9bf14d0 a4cfefc 9bf14d0 a4cfefc 9bf14d0 d9ea17d a3af518 a4cfefc a3af518 f63f843 a4cfefc f63f843 9bf14d0 f63f843 a4cfefc 986d4cd a4cfefc f63f843 a3af518 a4cfefc f63f843 a3af518 f63f843 a4cfefc f63f843 a4cfefc 9bf14d0 a4cfefc 9bf14d0 a3af518 a4cfefc 9bf14d0 a4cfefc a8606ac 2c15189 a09ea48 4189fe1 f63f843 a4cfefc f63f843 a4cfefc f63f843 a4cfefc f63f843 a4cfefc f63f843 a4cfefc f63f843 a4cfefc f63f843 a4cfefc f63f843 a4cfefc f63f843 d9ea17d a4cfefc f63f843 a4cfefc a3af518 a4cfefc c70d8eb 4189fe1 a4cfefc a09ea48 a4cfefc a09ea48 a4cfefc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import os
import json
import asyncio
import torch
# Bugfix für PyTorch 2.2.x Flash‑SDP‑Assertion
torch.backends.cuda.enable_flash_sdp(False)
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Hugging Face authentication: log in only when HF_TOKEN is set ---
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# --- Compute device: use the GPU when PyTorch can see one, else the CPU ---
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- FastAPI application instance ---
app = FastAPI()
# --- Root route so that GET / answers instead of returning 404 ---
@app.get("/")
async def read_root():
    """Return a short status message confirming the server is running."""
    status = {"message": "Orpheus TTS WebSocket Server läuft"}
    return status
# --- Load models once at application startup ---
@app.on_event("startup")
async def load_models():
    """
    Load SNAC, the Orpheus tokenizer/model and the sampling whitelist.

    Populates the module globals ``snac``, ``tokenizer``, ``model`` and
    ``ALLOWED_IDS`` that the request handlers read.
    """
    global tokenizer, model, snac, ALLOWED_IDS

    # SNAC codec (24 kHz checkpoint) used to decode audio codes to waveforms.
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # Orpheus TTS model (German fine-tune).
    REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(REPO)
    # ``return_legacy_cache`` is intentionally not passed: it is not a
    # from_pretrained() option, and unknown keyword arguments are forwarded
    # to the model constructor. The handler only passes
    # ``out.past_key_values`` straight back, which works with either cache
    # format.
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    ).to(device)
    model.config.pad_token_id = model.config.eos_token_id

    # --- Logit masking whitelist ---
    # Audio tokens start at 128266. Within a 7-code frame, position i lives in
    # its own 4096-wide slice (decode_block subtracts i * 4096), so all seven
    # slices must be sampleable, not just the first one.
    AUDIO_OFFSET = 128266
    CODES_PER_FRAME = 7
    CODEBOOK_SIZE = 4096
    valid_audio = torch.arange(
        AUDIO_OFFSET,
        AUDIO_OFFSET + CODES_PER_FRAME * CODEBOOK_SIZE,
        device=device,
    )
    # 128257 is treated by the handler as "reset the frame buffer";
    # eos_token_id ends generation.
    # NOTE(review): Orpheus' end-of-speech token (128258) is not whitelisted,
    # so generation can only stop on eos_token_id — confirm the model emits it.
    ctrl_tokens = torch.tensor([128257, model.config.eos_token_id], device=device)
    ALLOWED_IDS = torch.cat([valid_audio, ctrl_tokens])
def sample_from_logits(
    logits: torch.Tensor,
    allowed_ids: "torch.Tensor | None" = None,
    temperature: float = 1.0,
) -> int:
    """
    Sample one token id, restricted to a whitelist of ids.

    Every id outside ``allowed_ids`` gets probability zero; the remaining
    logits are temperature-scaled and sampled with ``torch.multinomial``.

    Args:
        logits: scores for the next token, shape ``[1, vocab_size]``.
        allowed_ids: 1-D tensor of permitted token ids on the same device as
            ``logits``; defaults to the module-level ``ALLOWED_IDS``.
        temperature: positive divisor applied to the logits (1.0 = unchanged).

    Returns:
        The sampled token id as a Python int.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if allowed_ids is None:
        allowed_ids = ALLOWED_IDS
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Work in float32: bf16 logits (as loaded on CUDA) give very coarse
    # probabilities after softmax.
    scores = logits.float() / temperature
    mask = torch.full_like(scores, float("-inf"))
    mask[..., allowed_ids] = 0.0
    probs = torch.softmax(scores + mask, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()
def prepare_inputs(text: str, voice: str):
    """
    Build the model input tensors for ``text`` spoken by ``voice``.

    The tokenized prompt ``"<voice>: <text>"`` is wrapped in marker tokens:
    128259 before it, and 128009, 128260 after it.

    Returns:
        ``(input_ids, attention_mask)``, both on ``device``; the mask is all
        ones with the same shape as ``input_ids``.
    """
    prefix = torch.tensor([[128259]], dtype=torch.int64, device=device)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt")
    prompt_ids = encoded.input_ids.to(device)
    input_ids = torch.cat((prefix, prompt_ids, suffix), dim=1)
    return input_ids, torch.ones_like(input_ids, device=device)
def decode_block(block: list[int]) -> bytes:
    """
    Decode one frame of 7 audio codes into PCM16 bytes with SNAC.

    Each entry is expected to be ``raw_token - 128266``. Position ``i`` of the
    frame additionally carries an offset of ``i * 4096``, which is removed
    here. The seven codes are spread over SNAC's three code layers:
    position 0 -> layer 1, positions 1 and 4 -> layer 2,
    positions 2, 3, 5 and 6 -> layer 3.

    Args:
        block: exactly 7 offset-adjusted audio codes.

    Returns:
        The decoded waveform as int16 samples (native byte order), clipped
        to the int16 range.

    Raises:
        ValueError: if ``block`` does not hold 7 codes, or a code falls
            outside ``[0, 4096)`` after removing its positional offset. This
            is checked up front: on CUDA, an out-of-range embedding index
            triggers a device-side assert instead of a catchable error.
    """
    codebook_size = 4096
    layer_of_position = (0, 1, 2, 2, 1, 2, 2)
    if len(block) != len(layer_of_position):
        raise ValueError(
            f"expected {len(layer_of_position)} audio codes per frame, got {len(block)}"
        )

    layers: list[list[int]] = [[], [], []]
    for pos, code in enumerate(block):
        value = code - pos * codebook_size
        if not 0 <= value < codebook_size:
            raise ValueError(f"audio code {code} at frame position {pos} is out of range")
        layers[layer_of_position[pos]].append(value)

    dev = next(snac.parameters()).device
    codes = [torch.tensor(layer, device=dev).unsqueeze(0) for layer in layers]
    # Without inference_mode the SNAC output requires grad and .numpy() raises.
    with torch.inference_mode():
        audio = snac.decode(codes).squeeze().cpu().numpy()
    # Convert to PCM16; clip first so out-of-range samples cannot wrap around.
    pcm16 = (audio.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
    return pcm16
# --- WebSocket endpoint for TTS streaming ---
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """
    Stream synthesized speech to a WebSocket client.

    Protocol: the client sends one JSON text message with ``"text"`` (default
    ``""``) and optionally ``"voice"`` (default ``"Jakob"``). The server then
    sends one binary message per completed 7-code audio frame (the bytes
    returned by ``decode_block``). It closes the socket when generation ends.
    On an unexpected error it prints the error and closes with code 1011.
    """
    await ws.accept()

    # Hard cap on sampled tokens (7 per audio frame) so generation cannot
    # run forever if no end token is ever sampled.
    max_new_tokens = 4096
    start_of_speech = 128257
    end_of_speech = 128258
    audio_offset = 128266

    def _step(step_ids, step_mask, past):
        # Runs in a worker thread. inference_mode is thread-local, so it is
        # entered here. Without it each step extends an autograd graph that
        # the KV cache keeps alive.
        with torch.inference_mode():
            out = model(
                input_ids=step_ids,
                attention_mask=step_mask,
                past_key_values=past,
                use_cache=True,
                return_dict=True,
            )
            tok = sample_from_logits(out.logits[:, -1, :])
        return out.past_key_values, tok

    def _decode(frame):
        # Also in a worker thread; inference_mode keeps the SNAC output
        # grad-free.
        with torch.inference_mode():
            return decode_block(frame)

    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # The first step feeds the whole prompt. Every later step feeds only
        # the newest token, because the prefix is already in the KV cache.
        input_ids, attention_mask = prepare_inputs(text, voice)
        past_kvs = None
        buffer: list[int] = []  # codes of the frame currently being collected

        for _ in range(max_new_tokens):
            # Model forward and SNAC decode are blocking work; keep them off
            # the event loop so other requests are still served.
            past_kvs, next_tok = await asyncio.to_thread(
                _step, input_ids, attention_mask, past_kvs
            )

            if next_tok in (model.config.eos_token_id, end_of_speech):
                break

            input_ids = torch.tensor([[next_tok]], device=device)
            attention_mask = None

            if next_tok == start_of_speech:
                # New audio section: discard any partial frame.
                buffer.clear()
                continue
            if next_tok < audio_offset:
                # Other control token: keep it in context, but it is no audio code.
                continue

            buffer.append(next_tok - audio_offset)
            if len(buffer) == 7:
                frame, buffer = buffer, []
                pcm = await asyncio.to_thread(_decode, frame)
                await ws.send_bytes(pcm)

        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except RuntimeError:
            # The socket was already closed; nothing left to report.
            pass
# --- CLI entry point for local testing ---
if __name__ == "__main__":
    import uvicorn

    # Pass the app object itself. The "app:app" import string only works when
    # this file is named app.py, and it makes uvicorn import (and re-execute)
    # this module a second time under a different name.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
|