# app.py ──────────────────────────────────────────────────────────────
import os, json, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from snac import SNAC
# 0) Login + device ----------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.enable_flash_sdp(False)  # workaround for a PyTorch 2.2 SDPA bug
# 1) Constants ---------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7  # 28,672 codes
AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)  # all valid audio-code token IDs
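# Audio tokens come in blocks of 7: the token at block position i is drawn
# from its own 4096-code slice [AUDIO_BASE + i*4096, AUDIO_BASE + (i+1)*4096).
# AudioMask below enforces this layout during generation; decode_block undoes
# it before handing the codes to SNAC.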
# 2) Logit mask (NEW_BLOCK + audio; EOS only after the 1st block) ------
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor):
        super().__init__()
        self.allow = torch.cat([
            torch.tensor([NEW_BLOCK], device=audio_ids.device),
            audio_ids
        ])
        self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
        self.sent_blocks = 0
        self.buffer_pos = 0  # position (0..7) within the current 7-token block

    def __call__(self, input_ids, logits):
        # Track the block position from the previously sampled token:
        # NEW_BLOCK starts a fresh block, an audio token advances the position.
        last = input_ids[0, -1].item()
        if last == NEW_BLOCK:
            self.buffer_pos = 0
        elif AUDIO_BASE <= last < AUDIO_BASE + AUDIO_SPAN:
            self.buffer_pos += 1

        if self.buffer_pos >= 7:
            # Block complete: only NEW_BLOCK may follow.
            allowed = torch.tensor([NEW_BLOCK], device=logits.device)
        else:
            # Otherwise only the 4096-code slice for the current position.
            start_token = AUDIO_BASE + self.buffer_pos * 4096
            allowed = torch.arange(start_token, start_token + 4096,
                                   device=logits.device)
        if self.sent_blocks:  # allow EOS once the first block has been sent
            allowed = torch.cat([allowed, self.eos])
        mask = logits.new_full(logits.shape, float("-inf"))
        mask[:, allowed] = 0
        return logits + mask
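# Concretely, the mask allows the following token IDs per block position:
#   pos 0: 128266 .. 132361   (slice 0)
#   pos 1: 132362 .. 136457   (slice 1)
#   ...
#   pos 6: 152842 .. 156937   (slice 6)
#   pos 7: NEW_BLOCK (128257), plus EOS (128258) once a block was streamed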
# 3) FastAPI skeleton --------------------------------------------------
app = FastAPI()

@app.get("/")
def hello():
    return {"status": "ok"}

@app.on_event("startup")
def load_models():
    global tok, model, snac, masker
    print("⏳ Loading models …", flush=True)
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    model.config.pad_token_id = model.config.eos_token_id
    masker = AudioMask(AUDIO_IDS.to(device))
    print("✅ Models loaded", flush=True)
# 4) Helpers -----------------------------------------------------------
def build_prompt(text: str, voice: str):
    prompt_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    ids = torch.cat([torch.tensor([[START_TOKEN]], device=device),
                     prompt_ids,
                     torch.tensor([[128009, 128260]], device=device)], 1)
    attn = torch.ones_like(ids)
    return ids, attn
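# Illustrative shapes only (token counts depend on the tokenizer):
#   ids, attn = build_prompt("Hallo Welt", "Jakob")
#   ids.shape  -> (1, n)  # [START_TOKEN, "Jakob: Hallo Welt" ids, 128009, 128260]
#   attn.shape -> (1, n)  # all ones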
def decode_block(block7: list[int]) -> bytes:
    # block7 holds 7 codes already relative to AUDIO_BASE (see the WS loop),
    # so only the per-position offset i*4096 still has to be removed before
    # the codes are split across SNAC's three codebook levels (1 + 2 + 4).
    l1 = [block7[0] - 0 * 4096]
    l2 = [block7[1] - 1 * 4096,
          block7[4] - 4 * 4096]
    l3 = [block7[2] - 2 * 4096, block7[3] - 3 * 4096,
          block7[5] - 5 * 4096, block7[6] - 6 * 4096]
    with torch.no_grad():
        codes = [torch.tensor(x, device=device).unsqueeze(0)
                 for x in (l1, l2, l3)]
        audio = snac.decode(codes).squeeze().detach().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
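# The 1 + 2 + 4 split above mirrors SNAC's three hierarchical codebook levels.
# The returned bytes are raw 16-bit mono PCM; for the snac_24khz model used
# here that should correspond to a 24 kHz sample rate.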
# 5) WebSocket endpoint ------------------------------------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_prompt(text, voice)
        offset_len = ids.size(1)  # how many tokens already exist
        buf = []
        # masker is a module-level singleton, so its state has to be reset
        # per request (concurrent connections would still interfere).
        masker.buffer_pos = 0
        masker.sent_blocks = 0

        while True:
            # --- mini-generate (KV cache disabled for debugging) ------
            gen = model.generate(
                input_ids=ids,            # always the full sequence
                attention_mask=attn,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[masker],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=False,          # cache disabled
                return_dict_in_generate=True,
            )
            # ----- slice out the newly generated tokens ---------------
            seq = gen.sequences[0].tolist()
            new = seq[offset_len:]
            if not new:                   # nothing new -> done
                break
            offset_len += len(new)
            # ----- feed the full sequence back in (no cache) ----------
            ids = torch.tensor([seq], device=device)
            attn = torch.ones_like(ids)
            print("new tokens:", new[:25], flush=True)

            # ----- token handling --------------------------------------
            for t in new:
                if t == EOS_TOKEN:
                    raise StopIteration   # caught below: end of stream
                if t == NEW_BLOCK:
                    buf.clear()
                    continue
                if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                    buf.append(t - AUDIO_BASE)  # store relative to AUDIO_BASE
                    if len(buf) == 7:
                        await ws.send_bytes(decode_block(buf))
                        buf.clear()
                        masker.sent_blocks = 1  # EOS allowed from now on
                else:
                    print(f"DEBUG: skipping non-audio token: {t}", flush=True)
    except (StopIteration, WebSocketDisconnect):
        pass
    except Exception as e:
        print("❌ WS error:", e, flush=True)
        import traceback
        traceback.print_exc()
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                pass
# 6) Dev start ---------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
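
# ── Usage sketch ──────────────────────────────────────────────────────
# A minimal client for manual testing (not part of the app), assuming the
# third-party `websockets` package; it writes the streamed raw PCM to a file:
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:   # 16-bit mono PCM, 24 kHz
#               while True:
#                   try:
#                       f.write(await ws.recv())
#                   except websockets.exceptions.ConnectionClosed:
#                       break
#
#   asyncio.run(main())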