# app.py ─────────────────────────────────────────────────────────────
import os, json, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
from snac import SNAC
# ── 0. HF auth & device ─────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Work around a Flash-Attention bug in PyTorch 2.2.x
torch.backends.cuda.enable_flash_sdp(False)
# ── 1. Constants ─────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
CHUNK_TOKENS = 50
START_TOKEN = 128259
NEW_BLOCK_TOKEN = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266
# 7 codebook slots of 4096 codes each; a range of only 4096 would mask
# out slots 2-7 (cf. the per-slot offsets subtracted in decode_snac).
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096)
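# One SNAC frame is 7 consecutive audio tokens, one per slot;
# NEW_BLOCK_TOKEN resets the current frame and EOS_TOKEN ends the
# utterance.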
# ── 2. Logits processor for masking ─────────────────────────────────
class AudioLogitMask(LogitsProcessor):
    def __init__(self, allowed_ids: torch.Tensor):
        super().__init__()
        self.allowed = allowed_ids

    def __call__(self, input_ids, scores):
        # scores shape: [batch, vocab]
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = 0
        return scores + mask

ALLOWED_IDS = torch.cat(
    [VALID_AUDIO_IDS, torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])]
).to(device)
MASKER = AudioLogitMask(ALLOWED_IDS)
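# With everything outside ALLOWED_IDS forced to -inf, sampling can only
# yield audio codes, NEW_BLOCK_TOKEN, or EOS_TOKEN, so the decode loop
# below never has to handle plain text tokens.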
# ── 3. FastAPI skeleton ──────────────────────────────────────────────
app = FastAPI()

@app.get("/")
async def ping():
    return {"msg": "Orpheus-TTS OK"}
@app.on_event("startup")
async def load_models():
    global tok, model, snac
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True
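# use_cache=True lets the chunked generate() calls in the WebSocket
# handler below reuse the KV cache across chunks instead of re-encoding
# the whole prompt each time.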
# ── 4. Helper functions ─────────────────────────────────────────────
def build_prompt(text: str, voice: str):
    base = f"{voice}: {text}"
    ids = tok(base, return_tensors="pt").input_ids.to(device)
    ids = torch.cat(
        [
            torch.tensor([[START_TOKEN]], device=device),
            ids,
            torch.tensor([[128009, 128260]], device=device),
        ],
        dim=1,
    )
    return ids, torch.ones_like(ids)
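# The resulting prompt is [START_TOKEN] "<voice>: <text>" [128009, 128260];
# the trailing pair appears to close the text segment before audio
# generation starts (per the Orpheus prompt format).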
def decode_snac(block7: list[int]) -> bytes:
    # De-interleave one 7-token frame into SNAC's three codebook levels,
    # removing the per-slot offset of 4096 * slot_index.
    b = block7
    l1 = [b[0]]
    l2 = [b[1] - 4096, b[4] - 16384]
    l3 = [b[2] - 8192, b[3] - 12288, b[5] - 20480, b[6] - 24576]
    with torch.no_grad():  # inference only, no autograd graph needed
        codes = [torch.tensor(x, device=device).unsqueeze(0)
                 for x in (l1, l2, l3)]
        audio = snac.decode(codes).squeeze().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
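# decode_snac emits raw 16-bit mono PCM at SNAC's 24 kHz sample rate;
# each binary WebSocket message below carries exactly one decoded frame.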
# ── 5. WebSocket endpoint ───────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, _ = build_prompt(text, voice)
        past = None
        buf = []
        done = False
        while not done:
            prompt_len = ids.shape[1]
            out = model.generate(
                input_ids=ids,
                attention_mask=torch.ones_like(ids),
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[MASKER],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=True,
                return_dict_in_generate=True,
            )
            past = out.past_key_values
            # generate() returns the prompt plus the new tokens; slice off
            # only the freshly generated part.
            newtok = out.sequences[0, prompt_len:].tolist()
            for t in newtok:
                if t == EOS_TOKEN:
                    done = True
                    break
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_snac(buf))
                    buf.clear()
            # From here on, keep generating with the cache: feed the full
            # sequence back in; with past_key_values set, transformers
            # re-encodes only the suffix that is not yet cached.
            ids = out.sequences
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("WS error:", e)
        await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            await ws.close()
# ── 6. Local test ───────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
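# ── 7. Example client (sketch) ──────────────────────────────────────
# Minimal WebSocket client for manual testing. Assumes the third-party
# `websockets` package is installed; it is not a dependency of this app.
#
#   import asyncio, json, websockets
#
#   async def main():
#       uri = "ws://localhost:7860/ws/tts"
#       async with websockets.connect(uri) as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               async for chunk in ws:       # one SNAC frame per message
#                   f.write(chunk)           # raw 16-bit mono PCM, 24 kHz
#
#   asyncio.run(main())
#
# Play the result e.g. with: ffplay -f s16le -ar 24000 -ac 1 out.pcm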