# app.py ─────────────────────────────────────────────────────────────
import os, json, asyncio, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import (AutoTokenizer, AutoModelForCausalLM, LogitsProcessor)
from transformers.cache_utils import Cache
from snac import SNAC
# ── 0. HF login & device ────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
login(HF_TOKEN)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Work around a Flash-Attention bug in PyTorch 2.2.x
torch.backends.cuda.enable_flash_sdp(False)
# ── 1. Constants ────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50 # tokens per mini-generate step
START_TOKEN = 128259
NEW_BLOCK_TOKEN = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266 # first audio token id
VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096) # 7 codebook slots x 4096 codes each
# ── 2. Dynamic logit masker ─────────────────────────────────────────
class DynamicAudioMask(LogitsProcessor):
"""
blockt EOS, bis mindestens `min_audio_blocks` gesendet wurden
"""
def __init__(self, audio_ids: torch.Tensor, min_audio_blocks: int = 1):
super().__init__()
self.audio_ids = audio_ids
self.ctrl_ids = torch.tensor([NEW_BLOCK_TOKEN], device=audio_ids.device)
self.min_blocks = min_audio_blocks
self.blocks_done = 0
def __call__(self, input_ids, scores):
allowed = torch.cat([self.audio_ids, self.ctrl_ids])
        if self.blocks_done >= self.min_blocks:  # EOS is now allowed
allowed = torch.cat([allowed, torch.tensor([EOS_TOKEN], device=scores.device)])
mask = torch.full_like(scores, float("-inf"))
mask[:, allowed] = 0
return scores + mask
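# A quick sanity check for the masker (illustrative only; the vocab width
# of 160000 below is an assumption, not taken from the model config):
#
#   m = DynamicAudioMask(VALID_AUDIO_IDS)
#   scores = torch.zeros(1, 160000)
#   assert m(None, scores)[0, EOS_TOKEN] == float("-inf")  # EOS blocked
#   m.blocks_done = 1
#   assert m(None, scores)[0, EOS_TOKEN] == 0              # EOS allowed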
# ── 3. FastAPI skeleton ─────────────────────────────────────────────
app = FastAPI()
@app.get("/")
async def ping():
return {"msg": "OrpheusβTTS up & running"}
@app.on_event("startup")
async def load_models():
global tok, model, snac, masker
print("β³Β Lade Modelle β¦")
tok = AutoTokenizer.from_pretrained(REPO)
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
model = AutoModelForCausalLM.from_pretrained(
REPO,
low_cpu_mem_usage=True,
device_map={"": 0} if device == "cuda" else None,
torch_dtype=torch.bfloat16 if device == "cuda" else None,
)
model.config.pad_token_id = model.config.eos_token_id
model.config.use_cache = True
masker = DynamicAudioMask(VALID_AUDIO_IDS.to(device))
print("β
Β Modelle geladen")
# ── 4. Helper functions ─────────────────────────────────────────────
def build_inputs(text: str, voice: str):
prompt = f"{voice}: {text}"
ids = tok(prompt, return_tensors="pt").input_ids.to(device)
    ids = torch.cat(
        [
            torch.tensor([[START_TOKEN]], device=device),  # start-of-generation marker
            ids,
            torch.tensor([[128009, 128260]], device=device),  # trailing control tokens
        ],
        dim=1,
    )
attn = torch.ones_like(ids)
return ids, attn
def decode_block(block7: list[int]) -> bytes:
    # Un-interleave one 7-token frame into the three SNAC codebook levels;
    # each slot carries a fixed offset (i * 4096) that is removed here.
    l1, l2, l3 = [], [], []
    b = block7
    l1.append(b[0])                           # level 1: 1 code per frame
    l2.append(b[1] - 4096)                    # level 2: 2 codes per frame
    l3.extend([b[2] - 8192, b[3] - 12288])    # level 3: 4 codes per frame
    l2.append(b[4] - 16384)
    l3.extend([b[5] - 20480, b[6] - 24576])
codes = [
torch.tensor(l1, device=device).unsqueeze(0),
torch.tensor(l2, device=device).unsqueeze(0),
torch.tensor(l3, device=device).unsqueeze(0),
]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()  # mono 16-bit PCM
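# For reference, the layout of one 7-token frame as it reaches this function
# (derived from the offsets removed above):
#
#   position:  0     1      2      3      4      5      6
#   codebook:  L1    L2     L3     L3     L2     L3     L3
#   offset:    0     4096   8192   12288  16384  20480  24576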
# ── 5. WebSocket TTS endpoint ───────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
await ws.accept()
try:
req = json.loads(await ws.receive_text())
text = req.get("text", "")
voice = req.get("voice", "Jakob")
        ids, attn = build_inputs(text, voice)
        masker.blocks_done = 0  # reset the shared masker for this request
        past = None
        buf = []
while True:
            out = model.generate(
                input_ids=ids,                  # full sequence; generate() crops
                attention_mask=attn,            #   the part already in the cache
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[masker],      # dynamic masker
                do_sample=True, temperature=0.7, top_p=0.95,
                return_dict_in_generate=True,
                use_cache=True,
            )
            # Extract the cache & the newly generated tokens ------------------
            pkv = out.past_key_values
            if isinstance(pkv, Cache):          # HF >= 4.47
                pkv = pkv.to_legacy_cache()
            past = pkv
            new = out.sequences[0, ids.shape[1]:].tolist()
            print("new tokens:", new[:20])      # debug output
# ----------------------------------------------------------------
for t in new:
if t == EOS_TOKEN:
raise StopIteration
if t == NEW_BLOCK_TOKEN:
buf.clear()
continue
buf.append(t - AUDIO_BASE)
if len(buf) == 7:
await ws.send_bytes(decode_block(buf))
buf.clear()
                        masker.blocks_done += 1  # EOS may be allowed from now on
            # Next step: feed the full sequence back in; the KV cache makes
            # generate() skip everything it has already processed.
            ids = out.sequences
            attn = torch.ones_like(ids)
except (StopIteration, WebSocketDisconnect):
pass
except Exception as e:
print("βΒ WSβError:", e)
if ws.client_state.name != "DISCONNECTED":
await ws.close(code=1011)
finally:
if ws.client_state.name != "DISCONNECTED":
try:
await ws.close()
except RuntimeError:
                pass  # close frame was already sent
# ── 6. Local start (uvicorn) ────────────────────────────────────────
if __name__ == "__main__":
import uvicorn
uvicorn.run("app:app", host="0.0.0.0", port=7860)
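# Minimal client sketch (illustrative; assumes the third-party `websockets`
# package, which is not used by this app itself):
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               try:
#                   while True:
#                       f.write(await ws.recv())  # raw mono 16-bit PCM, 24 kHz
#               except websockets.exceptions.ConnectionClosed:
#                   pass
#
#   asyncio.run(demo())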