Spaces:
Paused
Paused
# app.py βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
import os, json, asyncio, torch | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
from huggingface_hub import login | |
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor | |
from snac import SNAC | |
# ββ 0.Β HFβAuth & Device ββββββββββββββββββββββββββββββββββββββββββββββ | |
HF_TOKEN = os.getenv("HF_TOKEN") | |
if HF_TOKEN: | |
login(HF_TOKEN) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# FlashβAttentionβBug in PyTorchΒ 2.2.x | |
torch.backends.cuda.enable_flash_sdp(False) | |
# ββ 1.Β Konstanten ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1" | |
CHUNK_TOKENS = 50 | |
START_TOKEN = 128259 | |
NEW_BLOCK_TOKEN = 128257 | |
EOS_TOKEN = 128258 | |
AUDIO_BASE = 128266 | |
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096) | |
# ββ 2.Β LogitβProcessor zum Maskieren ββββββββββββββββββββββββββββββββ | |
class AudioLogitMask(LogitsProcessor): | |
def __init__(self, allowed_ids: torch.Tensor): | |
super().__init__() | |
self.allowed = allowed_ids | |
def __call__(self, input_ids, scores): | |
# scores shape: [batch, vocab] | |
mask = torch.full_like(scores, float("-inf")) | |
mask[:, self.allowed] = 0 | |
return scores + mask | |
ALLOWED_IDS = torch.cat( | |
[VALID_AUDIO_IDS, torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])] | |
).to(device) | |
MASKER = AudioLogitMask(ALLOWED_IDS) | |
# ββ 3.Β FastAPI β GrundgerΓΌst βββββββββββββββββββββββββββββββββββββββββ | |
app = FastAPI() | |
async def ping(): | |
return {"msg": "OrpheusβTTS OK"} | |
async def load_models(): | |
global tok, model, snac | |
tok = AutoTokenizer.from_pretrained(REPO) | |
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device) | |
model = AutoModelForCausalLM.from_pretrained( | |
REPO, | |
low_cpu_mem_usage=True, | |
device_map={"": 0} if device == "cuda" else None, | |
torch_dtype=torch.bfloat16 if device == "cuda" else None, | |
) | |
model.config.pad_token_id = model.config.eos_token_id | |
model.config.use_cache = True | |
# ββ 4.Β HilfsβFunktionen βββββββββββββββββββββββββββββββββββββββββββββ | |
def build_prompt(text:str, voice:str): | |
base = f"{voice}: {text}" | |
ids = tok(base, return_tensors="pt").input_ids.to(device) | |
ids = torch.cat( | |
[ | |
torch.tensor([[START_TOKEN]], device=device), | |
ids, | |
torch.tensor([[128009, 128260]], device=device), | |
], | |
1, | |
) | |
return ids, torch.ones_like(ids) | |
def decode_snac(block7:list[int])->bytes: | |
l1,l2,l3=[],[],[] | |
b=block7 | |
l1.append(b[0]) | |
l2.append(b[1]-4096) | |
l3.extend([b[2]-8192, b[3]-12288]) | |
l2.append(b[4]-16384) | |
l3.extend([b[5]-20480, b[6]-24576]) | |
codes=[torch.tensor(x,device=device).unsqueeze(0) | |
for x in (l1,l2,l3)] | |
audio=snac.decode(codes).squeeze().cpu().numpy() | |
return (audio*32767).astype("int16").tobytes() | |
# ββ 5.Β WebSocketβEndpoint βββββββββββββββββββββββββββββββββββββββββββ | |
async def tts(ws: WebSocket): | |
await ws.accept() | |
try: | |
req = json.loads(await ws.receive_text()) | |
text = req.get("text","") | |
voice = req.get("voice","Jakob") | |
ids, attn = build_prompt(text, voice) | |
past = None | |
buf = [] | |
while True: | |
out = model.generate( | |
input_ids=ids if past is None else None, | |
attention_mask=attn if past is None else None, | |
past_key_values=past, | |
max_new_tokens=CHUNK_TOKENS, | |
logits_processor=[MASKER], | |
do_sample=True, temperature=0.7, top_p=0.95, | |
use_cache=True, | |
return_dict_in_generate=True, | |
) | |
past = out.past_key_values | |
newtok = out.sequences[0,-out.num_generated_tokens:].tolist() | |
for t in newtok: | |
if t==EOS_TOKEN: | |
raise StopIteration | |
if t==NEW_BLOCK_TOKEN: | |
buf.clear(); continue | |
buf.append(t-AUDIO_BASE) | |
if len(buf)==7: | |
await ws.send_bytes(decode_snac(buf)) | |
buf.clear() | |
# ab jetzt nur noch mit Cache weiterβgenerieren | |
ids, attn = None, None | |
except (StopIteration, WebSocketDisconnect): | |
pass | |
except Exception as e: | |
print("WSβError:", e) | |
await ws.close(code=1011) | |
finally: | |
if ws.client_state.name!="DISCONNECTED": | |
await ws.close() | |
# ββ 6.Β Lokaler Test βββββββββββββββββββββββββββββββββββββββββββββββββ | |
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run("app:app", host="0.0.0.0", port=7860) | |