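# Scraped file-viewer header (Hugging Face Hub page for this file):
#   author: Tomtom84 · commit message: "Update app.py" · commit 5031731 (verified)
#   views: raw / history / blame · size: 7.04 kB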
# app.py ─────────────────────────────────────────────────────────────
import os, json, asyncio, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import (AutoTokenizer, AutoModelForCausalLM, LogitsProcessor)
from transformers.generation.utils import Cache
from snac import SNAC
# ── 0. Hugging Face login & compute device ───────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    # Only log in when a token is configured (presumably needed for
    # gated/private Hub repos — confirm for the repos used below).
    login(HF_TOKEN)

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Work around a Flash-Attention bug in PyTorch 2.2.x by disabling the
# flash scaled-dot-product-attention kernel globally.
torch.backends.cuda.enable_flash_sdp(False)
# ── 1. Constants ─────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50          # max new tokens per incremental generate() call
START_TOKEN = 128259       # prepended to every prompt (see build_inputs)
NEW_BLOCK_TOKEN = 128257   # control token: discards the partial audio frame
EOS_TOKEN = 128258         # control token: ends the audio stream
AUDIO_BASE = 128266        # id of the first audio-code token
AUDIO_CODEBOOK_SIZE = 4096  # codes per SNAC codebook entry window
CODES_PER_FRAME = 7         # audio tokens per decoded frame (see decode_block)
# Every position k (0..6) inside a frame has its own id window
#   [AUDIO_BASE + k*4096, AUDIO_BASE + (k+1)*4096),
# and decode_block removes that k*4096 offset again. All seven windows must
# therefore be allowed; allowing only the first one forces every token into
# the position-0 window and makes positions 1..6 decode to negative codes.
VALID_AUDIO_IDS = torch.arange(
    AUDIO_BASE, AUDIO_BASE + CODES_PER_FRAME * AUDIO_CODEBOOK_SIZE
)
# ── 2. Dynamischer Logit‑Masker ──────────────────────────────────────
class DynamicAudioMask(LogitsProcessor):
    """Logits processor that only lets audio codes and control tokens through.

    Every vocabulary entry except ``audio_ids`` and ``NEW_BLOCK_TOKEN`` gets
    ``-inf``. ``EOS_TOKEN`` is additionally blocked until ``blocks_done``
    reaches ``min_audio_blocks``; the caller increments ``blocks_done`` itself
    after each emitted audio block.

    The instance is stateful (``blocks_done``), so use one instance per
    generation request, or call :meth:`reset` before reusing it.
    """

    def __init__(self, audio_ids: torch.Tensor, min_audio_blocks: int = 1):
        super().__init__()
        self.audio_ids = audio_ids
        self.ctrl_ids = torch.tensor([NEW_BLOCK_TOKEN], device=audio_ids.device)
        self.min_blocks = min_audio_blocks
        self.blocks_done = 0
        # Both allow-lists are fixed, so build them once instead of
        # re-concatenating tensors on every decoding step.
        self._allowed = torch.cat([self.audio_ids, self.ctrl_ids])
        self._allowed_with_eos = torch.cat(
            [self._allowed, torch.tensor([EOS_TOKEN], device=audio_ids.device)]
        )

    def reset(self) -> None:
        """Forget the emitted-block count, so EOS is blocked again."""
        self.blocks_done = 0

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        """Return ``scores`` with every disallowed token id set to ``-inf``.

        Args:
            input_ids: tokens generated so far (unused; required by the
                ``LogitsProcessor`` call signature).
            scores: next-token logits, indexed as ``scores[:, token_id]``.
        """
        if self.blocks_done >= self.min_blocks:  # EOS may now be sampled
            allowed = self._allowed_with_eos
        else:
            allowed = self._allowed
        # Index on the logits' device; audio_ids may live on another device.
        allowed = allowed.to(scores.device)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed] = 0
        return scores + mask
# ── 3. FastAPI Grundgerüst (app skeleton) ───────────────────────────
app = FastAPI()


@app.get("/")
async def ping():
    """Answer ``GET /`` with a fixed status message."""
    payload = {"msg": "Orpheus‑TTS up & running"}
    return payload
@app.on_event("startup")
async def load_models():
    """Load tokenizer, SNAC decoder and the causal LM into module globals.

    Registered as a FastAPI startup hook. Populates ``tok``, ``snac``,
    ``model`` and ``masker`` for the request handlers.
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")
    on_gpu = device == "cuda"
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # On GPU: pin the whole model to device 0 in bfloat16; on CPU keep
    # the library defaults (no device map, default dtype).
    load_kwargs = dict(
        low_cpu_mem_usage=True,
        device_map={"": 0} if on_gpu else None,
        torch_dtype=torch.bfloat16 if on_gpu else None,
    )
    model = AutoModelForCausalLM.from_pretrained(REPO, **load_kwargs)
    cfg = model.config
    cfg.pad_token_id = cfg.eos_token_id
    cfg.use_cache = True
    masker = DynamicAudioMask(VALID_AUDIO_IDS.to(device))
    print("βœ…Β Modelle geladen")
# ── 4. Hilfs‑Funktionen ──────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Tokenize a ``"<voice>: <text>"`` prompt and wrap it in framing tokens.

    Layout: ``[START_TOKEN] + tokens("<voice>: <text>") + [128009, 128260]``.

    Returns:
        ``(input_ids, attention_mask)``: two ``(1, seq_len)`` tensors on
        ``device``; the mask is all ones.
    """
    body = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    prefix = torch.tensor([[START_TOKEN]], device=device)
    suffix = torch.tensor([[128009, 128260]], device=device)  # fixed trailer
    input_ids = torch.cat([prefix, body, suffix], dim=1)
    return input_ids, torch.ones_like(input_ids)
def decode_block(block7: list[int]) -> bytes:
    """Decode one 7-token audio frame into 16-bit PCM bytes with SNAC.

    Args:
        block7: the seven audio tokens of one frame, already shifted by
            ``AUDIO_BASE`` (``token_id - AUDIO_BASE``). Position ``k`` still
            carries a ``k * 4096`` offset, which is removed here.

    Returns:
        The decoded samples scaled to int16 and serialized with
        ``ndarray.tobytes()`` (native byte order).

    Raises:
        ValueError: if a de-offset code falls outside ``[0, 4096)``. This
            fails fast instead of passing a bad index to SNAC, which on
            CUDA triggers a device-side assert.
    """
    b = block7
    # Spread the frame over SNAC's three codebook levels (1 + 2 + 4 codes),
    # removing each token's per-position 4096 offset.
    layer1 = [b[0]]
    layer2 = [b[1] - 1 * 4096, b[4] - 4 * 4096]
    layer3 = [b[2] - 2 * 4096, b[3] - 3 * 4096, b[5] - 5 * 4096, b[6] - 6 * 4096]
    if any(not 0 <= c < 4096 for c in (*layer1, *layer2, *layer3)):
        raise ValueError(f"audio frame contains out-of-range SNAC codes: {block7}")
    codes = [
        torch.tensor(layer, device=device).unsqueeze(0)
        for layer in (layer1, layer2, layer3)
    ]
    # Pure inference. nn.Module parameters require grad by default, so
    # without this the output would too and `.numpy()` below would raise.
    with torch.inference_mode():
        audio = snac.decode(codes)
    # Clamp before scaling so out-of-range samples saturate instead of
    # wrapping around when cast to int16.
    samples = audio.squeeze().float().clamp(-1.0, 1.0).cpu().numpy()
    return (samples * 32767).astype("int16").tobytes()
# ── 5. WebSocket‑TTS‑Endpoint ───────────────────────────────────────
def _generate_chunk(ids, attn, past, masker):
    """Run one bounded ``model.generate`` call (blocking; run it in a worker thread).

    ``ids``/``attn`` always hold the *full* sequence so far. When ``past``
    (the cache returned by the previous call) is given, generate() is
    expected to process only the tokens the cache does not cover yet.
    """
    return model.generate(
        input_ids=ids,
        attention_mask=attn,
        past_key_values=past,
        max_new_tokens=CHUNK_TOKENS,
        logits_processor=[masker],  # dynamic audio/EOS mask
        do_sample=True, temperature=0.7, top_p=0.95,
        eos_token_id=EOS_TOKEN,  # end the chunk early once the audio ends
        return_dict_in_generate=True,
        use_cache=True,
    )


async def _stream_tts(ws: WebSocket, text: str, voice: str, *,
                      max_total_tokens: int = 4096) -> None:
    """Generate speech for ``text`` in chunks and stream PCM frames to ``ws``.

    Stops at ``EOS_TOKEN``, or after ``max_total_tokens`` generated tokens,
    so a model that never emits EOS cannot loop (and grow its cache) forever.
    """
    # A fresh mask per request: its EOS gate is stateful (blocks_done), and a
    # shared instance would let EOS through immediately for later requests.
    masker = DynamicAudioMask(VALID_AUDIO_IDS.to(device))
    ids, attn = build_inputs(text, voice)
    past = None
    buf: list[int] = []
    generated = 0
    while generated < max_total_tokens:
        # generate() is synchronous and compute-heavy; keep it off the event
        # loop. NOTE: concurrent connections may now generate in parallel on
        # the shared model; add a lock if memory is tight.
        out = await asyncio.to_thread(_generate_chunk, ids, attn, past, masker)
        new = out.sequences[0, ids.shape[1]:].tolist()
        print("new tokens:", new[:20])  # debug output
        if not new:
            return
        generated += len(new)
        # Carry the whole sequence and its cache into the next call.
        ids = out.sequences
        attn = torch.ones_like(ids)
        past = out.past_key_values
        for t in new:
            if t == EOS_TOKEN:
                return
            if t == NEW_BLOCK_TOKEN:
                buf.clear()
                continue
            buf.append(t - AUDIO_BASE)
            if len(buf) == 7:
                pcm = await asyncio.to_thread(decode_block, buf.copy())
                await ws.send_bytes(pcm)
                buf.clear()
                masker.blocks_done += 1  # from here on EOS may be allowed


@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """WebSocket TTS endpoint.

    Protocol: the client sends one JSON text message with ``"text"`` and an
    optional ``"voice"`` (default ``"Jakob"``). The server answers with one
    binary int16 PCM message per decoded audio frame, then closes the socket
    with code 1000, or 1011 after an internal error.
    """
    await ws.accept()
    close_code = 1000
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")
        await _stream_tts(ws, text, voice)
    except WebSocketDisconnect:
        pass  # client went away; nothing left to send
    except Exception as e:
        print("❌ WS‑Error:", e)
        close_code = 1011
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close(code=close_code)
            except RuntimeError:
                pass  # close frame was already sent
# ── 6. Lokaler Start (uvicorn) ───────────────────────────────────────
if __name__ == "__main__":
    # Local start: serve the ASGI app "app:app" with uvicorn on 0.0.0.0:7860.
    from uvicorn import run as _serve

    _serve("app:app", host="0.0.0.0", port=7860)