# Hugging Face Space status when this copy was captured: Paused
# app.py -------------------------------------------------------------
import json
import os

import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
from transformers.generation.utils import Cache
# -- 0. Auth & device ------------------------------------------------
# Log in to the Hugging Face Hub only when a token is supplied through the
# HF_TOKEN environment variable. The token gets its own name so it is not
# confused with the module-level tokenizer handle `tok`.
if (hf_token := os.getenv("HF_TOKEN")):
    login(hf_token)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Turn off the flash scaled-dot-product-attention backend; the original
# comment labels this a PyTorch 2.2 fix.
torch.backends.cuda.enable_flash_sdp(False)
# -- 1. Constants ----------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50  # tokens per generate() call; <= 50 -> < 1 s latency
START_TOKEN = 128259
NEW_BLOCK_TOKEN = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266

# A frame consists of FRAME_SLOTS audio tokens. decode_block() subtracts
# slot * CODEBOOK_SIZE (0, 4096, ..., 24576) from `token - AUDIO_BASE`, so
# the ids that can decode to non-negative codes span FRAME_SLOTS consecutive
# CODEBOOK_SIZE-wide ranges above AUDIO_BASE. Allowing only the first range
# would make every code in slots 1..6 negative.
CODEBOOK_SIZE = 4096
FRAME_SLOTS = 7
VALID_AUDIO_IDS = torch.arange(
    AUDIO_BASE, AUDIO_BASE + FRAME_SLOTS * CODEBOOK_SIZE
)
# -- 2. Logit mask (audio and control tokens only) -------------------
class AudioMask(LogitsProcessor):
    """Logits processor that restricts sampling to an allow-list of ids.

    Every vocabulary entry whose id is not in ``allowed`` has ``-inf`` added
    to its score, so it can never be sampled.
    """

    def __init__(self, allowed: torch.Tensor):
        # Tensor of permitted token ids. It should already live on the model
        # device; __call__ moves it if it does not.
        self.allowed = allowed

    def __call__(self, _ids, scores):
        """Return ``scores`` with every non-allowed column set to ``-inf``.

        ``_ids`` (the tokens generated so far) is ignored.
        """
        # No-op when the tensor is already on the right device.
        allowed = self.allowed.to(scores.device)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed] = 0.0
        return scores + mask
# Ids the model may emit: every audio id plus the two control tokens that
# the streaming loop reacts to (new block, end of stream).
ALLOWED_IDS = torch.cat(
    [
        VALID_AUDIO_IDS,
        torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN]),
    ]
).to(device)

# Shared processor instance passed to every generate() call.
MASKER = AudioMask(ALLOWED_IDS)
# -- 3. FastAPI scaffolding ------------------------------------------
app = FastAPI()


# NOTE(review): this copy of the file contains no route decorator, so the
# handler was never registered with `app`. The "/" path is inferred from the
# function name -- confirm against the deployed service.
@app.get("/")
async def root():
    """Return a small JSON payload signalling that the service is up."""
    return {"msg": "Orpheus-TTS ready"}
# Global handles: tokenizer, language model and SNAC codec. They stay None
# until load_models() has run.
tok = model = snac = None


# NOTE(review): this copy of the file contains no startup hook, so nothing
# ever called load_models() and the handles above stayed None. Registering
# it as a startup handler is the assumed original intent -- confirm.
@app.on_event("startup")
async def load_models():
    """Load the tokenizer, the SNAC codec and the causal LM into the globals.

    The codec is moved to ``device`` and put in eval mode. The language model
    is placed on GPU 0 in bfloat16 when CUDA is available; otherwise the
    library defaults are used.
    """
    global tok, model, snac
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval()
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    # Reuse the EOS id for padding and make sure the KV cache is enabled.
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True
# -- 4. Helpers ------------------------------------------------------
def build_inputs(text: str, voice: str):
    """Build the prompt tensors for one synthesis request.

    The prompt text is ``"<voice>: <text>"``. Its token ids are wrapped as
    ``[START_TOKEN] + ids + [128009, 128260]``.

    Returns:
        A tuple ``(input_ids, attention_mask)`` of tensors on ``device``;
        the mask is all ones and has the same shape as ``input_ids``.
    """
    prompt = f"{voice}: {text}"
    ids = tok(prompt, return_tensors="pt").input_ids.to(device)
    # Use the tokenizer output's dtype for the marker tensors so torch.cat
    # never has to reconcile differing integer types.
    prefix = torch.tensor([[START_TOKEN]], dtype=ids.dtype, device=device)
    # NOTE(review): 128009 / 128260 are presumably the end-of-text and
    # end-of-turn markers of the prompt format -- confirm against the model.
    suffix = torch.tensor([[128009, 128260]], dtype=ids.dtype, device=device)
    ids = torch.cat([prefix, ids, suffix], dim=1)
    return ids, torch.ones_like(ids)
def decode_block(b7: list[int]) -> bytes:
    """Decode one frame of seven audio codes into 16-bit PCM bytes.

    ``b7`` holds seven values of ``token_id - AUDIO_BASE``. Slot ``i`` has
    ``i * 4096`` subtracted, and the results are distributed over three code
    layers (1, 2 and 4 codes respectively) that are passed to
    ``snac.decode``.

    Returns:
        The decoded audio scaled by 32767 and cast to ``int16``, as raw bytes.
    """
    layer1 = [b7[0]]
    layer2 = [b7[1] - 4096, b7[4] - 16384]
    layer3 = [b7[2] - 8192, b7[3] - 12288, b7[5] - 20480, b7[6] - 24576]
    codes = [
        torch.tensor(layer, device=device).unsqueeze(0)
        for layer in (layer1, layer2, layer3)
    ]
    # Decode without autograd: a tensor that requires grad cannot be
    # converted with .numpy(), and no gradients are needed for playback.
    with torch.inference_mode():
        audio = snac.decode(codes)
    # Clamp to [-1, 1] so out-of-range samples saturate rather than wrap
    # around when cast to int16.
    audio = audio.squeeze().clamp(-1.0, 1.0).cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
def new_tokens_only(full_seq, prev_len):
    """Return the tokens of ``full_seq`` that were added after ``prev_len``.

    ``full_seq`` may be any sliceable sequence. Slices that expose
    ``tolist()`` (tensors, arrays) are converted with it; plain lists and
    tuples are copied with ``list()``. The result is always a Python list.
    """
    tail = full_seq[prev_len:]
    if hasattr(tail, "tolist"):
        return tail.tolist()
    return list(tail)
# -- 5. WebSocket endpoint -------------------------------------------
# NOTE(review): this copy of the file contains no route decorator, so the
# endpoint was never registered with `app`. The "/ws/tts" path is an
# assumption -- confirm against the client.
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech for one request over a WebSocket.

    Protocol: the client sends a single JSON text message with the optional
    keys ``text`` (default ``""``) and ``voice`` (default ``"Jakob"``). The
    server answers with binary messages, one per complete 7-token frame (the
    bytes returned by ``decode_block``). The socket is closed normally once
    ``EOS_TOKEN`` is generated, or with code 1011 after an unexpected error.
    """
    await ws.accept()
    close_code = 1000  # default code of WebSocket.close()
    try:
        req = json.loads(await ws.receive_text())
        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
        seen = ids.size(1)  # tokens already handled (initially: the prompt)
        past = None
        buf = []
        finished = False
        while not finished:
            # NOTE(review): generate() is a synchronous call, so the event
            # loop is blocked while each chunk is produced.
            gen = model.generate(
                # Always pass the full running sequence together with the
                # cache; the cached prefix is then skipped and generation
                # continues where the previous chunk stopped.
                input_ids=ids,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[MASKER],
                do_sample=True, top_p=0.95, temperature=0.7,
                return_dict_in_generate=True,
                use_cache=True,
                return_legacy_cache=True,  # needed for transformers < 4.49
            )
            # Keep the cache in whatever format generate() returned; it is
            # fed back unchanged on the next iteration.
            past = gen.past_key_values
            ids = gen.sequences
            attn = torch.ones_like(ids)
            new_tok = new_tokens_only(ids[0], seen)
            seen = ids.size(1)
            if not new_tok:
                # Nothing was generated; retrying with identical inputs
                # could loop forever, so stop instead.
                break
            for t in new_tok:
                if t == EOS_TOKEN:
                    finished = True
                    break
                if t == NEW_BLOCK_TOKEN:
                    # Block boundary: drop any partially collected frame.
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
    except WebSocketDisconnect:
        # Client went away; nothing left to send.
        pass
    except Exception as e:
        print("WS-Error:", e)
        close_code = 1011
    finally:
        # Close exactly once. application_state leaves CONNECTED as soon as
        # a close frame has been sent, which guards against a second close.
        if (
            ws.client_state.name == "CONNECTED"
            and ws.application_state.name == "CONNECTED"
        ):
            await ws.close(code=close_code)
# -- 6. Local run ----------------------------------------------------
if __name__ == "__main__":
    import sys

    import uvicorn

    # An optional first command-line argument overrides the default port.
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
    uvicorn.run("app:app", host="0.0.0.0", port=port)