# app.py ────────────────────────────────────────────────────────────
import os, json, asyncio, torch, logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from snac import SNAC
# ── 0. Auth & Device ────────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
login(HF_TOKEN)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.enable_flash_sdp(False)  # work around a flash-SDP bug
logging.getLogger("transformers.generation.utils").setLevel("ERROR")
# ── 1. Constants ────────────────────────────────────────────────────
MODEL_REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50
START_TOKEN = 128259  # marks the start of the prompt
NEW_BLOCK_TOKEN = 128257  # opens a new audio block (resets the frame buffer)
EOS_TOKEN = 128258 # <eos>
PROMPT_END = [128009, 128260]
AUDIO_BASE = 128266
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096)  # all 7 slots
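# Each position of a 7-token audio frame occupies its own 4096-wide slot
# above AUDIO_BASE (offsets 0, 4096, ..., 24576); decode_block strips them.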
# ── 2. Logit masker ─────────────────────────────────────────────────
class AudioMask(LogitsProcessor):
def __init__(self, allowed: torch.Tensor):
super().__init__()
self.allowed = allowed
def __call__(self, input_ids, scores):
mask = torch.full_like(scores, float("-inf"))
mask[:, self.allowed] = 0
return scores + mask
ALLOWED_IDS = torch.cat([
VALID_AUDIO_IDS,
torch.tensor([START_TOKEN, NEW_BLOCK_TOKEN, EOS_TOKEN])
]).to(device)
MASKER = AudioMask(ALLOWED_IDS)
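# Passed to model.generate via logits_processor, so sampling can only ever
# pick an audio code or one of the three control tokens above.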
# ── 3. FastAPI skeleton ─────────────────────────────────────────────
app = FastAPI()
@app.get("/")
async def ping():
    return {"message": "Orpheus-TTS ready"}
@app.on_event("startup")
async def load_models():
global tok, model, snac
tok = AutoTokenizer.from_pretrained(MODEL_REPO)
model = AutoModelForCausalLM.from_pretrained(
MODEL_REPO,
low_cpu_mem_usage=True,
device_map={"": 0} if device == "cuda" else None,
torch_dtype=torch.bfloat16 if device == "cuda" else None,
)
model.config.pad_token_id = model.config.eos_token_id
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
# ── 4. Helper functions ─────────────────────────────────────────────
def build_inputs(text: str, voice: str):
prompt = f"{voice}: {text}" if voice and voice != "in_prompt" else text
ids = tok(prompt, return_tensors="pt").input_ids.to(device)
ids = torch.cat([
torch.tensor([[START_TOKEN]], device=device),
ids,
torch.tensor([PROMPT_END], device=device)
], 1)
mask = torch.ones_like(ids)
    return ids, mask  # shape (1, L)
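# Maps one 7-token frame onto SNAC's three code levels (1 coarse, 2 medium,
# 4 fine codes) by removing the per-position offsets, then decodes to PCM.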
def decode_block(block7: list[int]) -> bytes:
l1, l2, l3 = [], [], []
b = block7
l1.append(b[0])
l2.append(b[1] - 4096)
l3 += [b[2]-8192, b[3]-12288]
l2.append(b[4] - 16384)
l3 += [b[5]-20480, b[6]-24576]
codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
audio = snac.decode(codes).squeeze().cpu().numpy()
return (audio * 32767).astype("int16").tobytes()
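# The returned bytes are raw mono 16-bit PCM at 24 kHz (snac_24khz model).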
# ── 5. WebSocket endpoint ───────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
await ws.accept()
try:
req = json.loads(await ws.receive_text())
ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
prompt_len = ids.size(1)
past, buf = None, []
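        # Chunked streaming: sample CHUNK_TOKENS per generate() call and
        # hand past_key_values back in so earlier tokens are not re-run.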
while True:
out = model.generate(
input_ids=ids if past is None else None,
attention_mask=attn if past is None else None,
past_key_values=past,
max_new_tokens=CHUNK_TOKENS,
logits_processor=[MASKER],
do_sample=True, top_p=0.95, temperature=0.7,
return_dict_in_generate=True,
use_cache=True,
                return_legacy_cache=True,  # → the legacy-cache warning disappears
)
            past = out.past_key_values  # pass along unchanged
seq = out.sequences[0].tolist()
new = seq[prompt_len:]; prompt_len = len(seq)
            if not new:  # rare, but possible
continue
for t in new:
if t == EOS_TOKEN:
await ws.close()
return
if t == NEW_BLOCK_TOKEN:
buf.clear(); continue
                if t < AUDIO_BASE:  # should never happen thanks to the mask
continue
buf.append(t - AUDIO_BASE)
if len(buf) == 7:
await ws.send_bytes(decode_block(buf))
buf.clear()
            # From here on the KV cache alone carries the context;
            # input ids and attention mask are no longer needed.
            ids = attn = None
except WebSocketDisconnect:
pass
except Exception as e:
        print("WS error:", e)
if ws.client_state.name == "CONNECTED":
await ws.close(code=1011)
# ── 6. Local launch ─────────────────────────────────────────────────
if __name__ == "__main__":
import uvicorn, sys
port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
uvicorn.run("app:app", host="0.0.0.0", port=port)
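# ── 7. Example client (illustrative sketch, not executed) ───────────
# A minimal way to exercise /ws/tts from another process. Assumes the
# third-party `websockets` package (not a dependency of this app) and
# writes the raw 16-bit PCM stream to a file:
#
#   import asyncio, json, websockets
#
#   async def main():
#       uri = "ws://localhost:7860/ws/tts"
#       async with websockets.connect(uri) as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               async for chunk in ws:  # iteration ends when the server closes
#                   f.write(chunk)
#
#   asyncio.run(main())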