# app.py ──────────────────────────────────────────────────────────────
import os, json, asyncio, torch, logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache
from snac import SNAC
# ── 0. Auth & Device ────────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.enable_flash_sdp(False)  # work around a flash-attention SDP bug
logging.getLogger("transformers.generation.utils").setLevel("ERROR")
# ── 1. Constants ────────────────────────────────────────────────────
MODEL_REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50
START_TOKEN = 128259           # <s>, prepended to the prompt (see build_inputs)
NEW_BLOCK_TOKEN = 128257       # marks the start of a new audio block
EOS_TOKEN = 128258             # <eos>
PROMPT_END = [128009, 128260]  # appended after the prompt text
AUDIO_BASE = 128266            # first audio-code token id
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096)  # 7 slots x 4096 codes, matching the offsets in decode_block
# ── 2. Logits masker ────────────────────────────────────────────────
class AudioMask(LogitsProcessor):
    def __init__(self, allowed: torch.Tensor):
        super().__init__()
        self.allowed = allowed

    def __call__(self, input_ids, scores):
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = 0
        return scores + mask

ALLOWED_IDS = torch.cat([
    VALID_AUDIO_IDS,
    torch.tensor([START_TOKEN, NEW_BLOCK_TOKEN, EOS_TOKEN])
]).to(device)
MASKER = AudioMask(ALLOWED_IDS)
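
# Optional sanity check for the masker (a sketch; `_demo_audio_mask` is not
# called anywhere in the app). It builds a fake logits row and shows that ids
# outside ALLOWED_IDS are pushed to -inf while the control tokens stay usable.
def _demo_audio_mask() -> None:
    fake_scores = torch.zeros(1, AUDIO_BASE + 7 * 4096, device=device)
    fake_ids = torch.zeros(1, 1, dtype=torch.long, device=device)  # ignored by AudioMask
    masked = MASKER(fake_ids, fake_scores)
    assert masked[0, NEW_BLOCK_TOKEN].item() == 0.0   # control token stays sampleable
    assert masked[0, 0].item() == float("-inf")       # ordinary text token is blocked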
# ── 3. FastAPI skeleton ─────────────────────────────────────────────
app = FastAPI()

@app.get("/")
async def ping():
    return {"message": "Orpheus-TTS ready"}
@app.on_event("startup")
async def load_models():
    global tok, model, snac
    tok = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    model.config.pad_token_id = model.config.eos_token_id
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
# ── 4. Helper functions ─────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}" if voice and voice != "in_prompt" else text
    ids = tok(prompt, return_tensors="pt").input_ids.to(device)
    ids = torch.cat([
        torch.tensor([[START_TOKEN]], device=device),
        ids,
        torch.tensor([PROMPT_END], device=device)
    ], 1)
    mask = torch.ones_like(ids)
    return ids, mask  # shape (1, L)
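
# For reference, build_inputs assembles a prompt of the form (sketch):
#
#   [START_TOKEN] [token ids of "Jakob: <text>"] [128009, 128260]  (= PROMPT_END)
#
# i.e. one start marker, the (optionally voice-prefixed) text, then the two
# PROMPT_END tokens, after which the model is expected to emit audio codes.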
def decode_block(block7: list[int]) -> bytes:
    l1, l2, l3 = [], [], []
    b = block7
    l1.append(b[0])
    l2.append(b[1] - 4096)
    l3 += [b[2] - 8192, b[3] - 12288]
    l2.append(b[4] - 16384)
    l3 += [b[5] - 20480, b[6] - 24576]
    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
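
# How decode_block unpacks one 7-token frame (offsets relative to AUDIO_BASE):
#
#   slot 0 -> l1: b[0]           slot 4 -> l2: b[4] - 16384
#   slot 1 -> l2: b[1] - 4096    slot 5 -> l3: b[5] - 20480
#   slot 2 -> l3: b[2] - 8192    slot 6 -> l3: b[6] - 24576
#   slot 3 -> l3: b[3] - 12288
#
# SNAC then decodes the three coarse-to-fine code lists (1 + 2 + 4 codes per
# frame) into 24 kHz mono audio, which is scaled to 16-bit PCM for streaming.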
# ── 5. WebSocket endpoint ───────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
        prompt_len = ids.size(1)
        past, buf = None, []
        while True:
            out = model.generate(
                input_ids=ids,             # full sequence so far; the cache skips the prefix
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[MASKER],
                do_sample=True, top_p=0.95, temperature=0.7,
                return_dict_in_generate=True,
                use_cache=True,
                return_legacy_cache=True,  # silences the cache-format warning
            )
            past = out.past_key_values     # hand the cache back in unchanged next round
            seq = out.sequences[0].tolist()
            new = seq[prompt_len:]
            prompt_len = len(seq)
            if not new:                    # rare, but possible
                continue
            for t in new:
                if t == EOS_TOKEN:
                    await ws.close()
                    return
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue
                if t < AUDIO_BASE:         # the mask should make this impossible
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
            # Feed the whole sequence back in; thanks to the reused cache only the
            # new tokens are actually processed on the next generate call.
            ids = out.sequences
            attn = torch.ones_like(ids)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("WS error:", e)
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)
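
# A minimal client for this endpoint (a sketch, not used by the server). It
# assumes the third-party `websockets` package and a server on localhost:7860;
# the received frames are raw 16-bit mono PCM at 24 kHz.
async def _demo_client(text: str = "Hallo Welt", voice: str = "Jakob") -> bytes:
    import websockets  # imported here so the server itself does not require it
    pcm = b""
    async with websockets.connect("ws://localhost:7860/ws/tts") as client:
        await client.send(json.dumps({"text": text, "voice": voice}))
        try:
            while True:
                pcm += await client.recv()
        except websockets.exceptions.ConnectionClosed:
            pass
    return pcm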
# ── 6. Local start ──────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn, sys
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
    uvicorn.run("app:app", host="0.0.0.0", port=port)