# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback # Import traceback for better error logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC # Ensure you have 'pip install snac'
# --- Globals (populated in load_models_startup) ---
tok = None
model = None
snac = None
masker = None
stopping_criteria = None
# actual_eos_token_id = None # Reverted to constant below
device = "cuda" if torch.cuda.is_available() else "cpu"
# 0) Login + Device ---------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
print("πŸ”‘ Logging in to Hugging Face Hub...")
login(HF_TOKEN)
# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment if needed to work around a PyTorch 2.2 SDPA bug
# 1) Constants --------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
START_TOKEN = 128259
NEW_BLOCK = 128257
# --- Reverted to using the hardcoded EOS token based on user belief ---
EOS_TOKEN = 128258
# --- End Reverted EOS Token ---
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7 # 28672 Codes
CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
# Create AUDIO_IDS on the correct device later in load_models_startup
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
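# Worked example of the audio-token layout (a sketch derived from the constants
# above and the modulo extraction in _decode_block): each audio token id t lies
# in [AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN) and encodes one of seven codebook
# slots. With v = t - AUDIO_BASE, the slot is v // CODEBOOK_SIZE and the code
# value is v % CODEBOOK_SIZE; e.g. t = AUDIO_BASE + 2 * CODEBOOK_SIZE + 17
# denotes slot 2, code 17.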
# 2) Logit mask -------------------------------------------------------
# Uses the constant EOS_TOKEN
class AudioMask(LogitsProcessor):
def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
super().__init__()
new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
self.eos = eos_tensor
self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
self.sent_blocks = 0
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
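        # Gate EOS: it becomes a legal token only after at least one full audio
        # block has been streamed; before that, only NEW_BLOCK and audio codes.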
current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
mask = torch.full_like(scores, float("-inf"))
mask[:, current_allow] = 0
return scores + mask
def reset(self):
self.sent_blocks = 0
# 3) StoppingCriteria for EOS -----------------------------------------
# Uses the constant EOS_TOKEN
class EosStoppingCriteria(StoppingCriteria):
def __init__(self, eos_token_id: int):
self.eos_token_id = eos_token_id
# No warning needed here as we are intentionally using the constant
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if self.eos_token_id is None:
return False
        # The last-token comparison yields a tensor; require EOS across the
        # whole batch (batch size is 1 here) before stopping.
        if input_ids.shape[1] > 0 and bool((input_ids[:, -1] == self.eos_token_id).all()):
            print(f"StoppingCriteria: EOS detected (ID: {self.eos_token_id}).")  # Add log
            return True
return False
# 4) Custom AudioStreamer ---------------------------------------------
class AudioStreamer(BaseStreamer):
# Pass the constant EOS_TOKEN here too
def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
self.ws = ws
self.snac = snac_decoder
self.masker = audio_mask
self.loop = loop
self.device = target_device
self.eos_token_id = eos_token_id # Store constant EOS ID
self.buf: list[int] = []
self.tasks = set()
def _decode_block(self, block7: list[int]) -> bytes:
"""
Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
NOTE: Extracts base code value (0-4095) using modulo, assuming
input values represent (slot_offset + code_value).
Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
"""
if len(block7) != 7:
# print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
return b"" # Less verbose logging
try:
            # Extract the base code value (0..CODEBOOK_SIZE-1) for each of the
            # seven slots; incoming values represent (slot_offset + code_value).
            vals = [v % CODEBOOK_SIZE for v in block7]
            # Map the extracted code values to the SNAC codebooks (l1, l2, l3).
            l1 = [vals[0]]
            l2 = [vals[1], vals[4]]
            l3 = [vals[2], vals[3], vals[5], vals[6]]
except IndexError:
print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
return b""
except Exception as e_map:
print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
return b""
# --- Convert lists to tensors on the correct device ---
try:
codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
codes = [codes_l1, codes_l2, codes_l3]
except Exception as e_tensor:
print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
return b""
# --- Decode using SNAC ---
try:
with torch.no_grad():
audio = self.snac.decode(codes)[0]
except Exception as e_decode:
print(f"Streamer Error: Exception during snac.decode: {e_decode}")
# Add more details if needed, e.g., shapes: {[c.shape for c in codes]}
return b""
# --- Post-processing ---
try:
audio_np = audio.squeeze().detach().cpu().numpy()
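            # Scale the float waveform (-1..1) to 16-bit PCM; the SNAC 24 kHz
            # decoder output is mono, so this yields raw PCM16 bytes.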
audio_bytes = (audio_np * 32767).astype("int16").tobytes()
return audio_bytes
except Exception as e_post:
print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
return b""
async def _send_audio_bytes(self, data: bytes):
"""Coroutine to send bytes over WebSocket."""
if not data:
return
try:
await self.ws.send_bytes(data)
except WebSocketDisconnect:
# This is expected if client disconnects first, don't log error
# print("Streamer: WebSocket disconnected during send.")
pass
except Exception as e:
if "Cannot call \"send\" once a close message has been sent" in str(e) or \
"Connection is closed" in str(e):
# This is expected if client disconnects during generation, suppress repetitive logs
pass
else:
print(f"Streamer: Error sending bytes: {e}")
def put(self, value: torch.LongTensor):
"""
Receives new token IDs (Tensor) from generate().
Processes tokens, decodes full blocks, and schedules sending.
"""
if value.numel() == 0:
return
new_token_ids = value.squeeze().cpu().tolist()
if isinstance(new_token_ids, int):
new_token_ids = [new_token_ids]
for t in new_token_ids:
# No need to check for EOS here, StoppingCriteria handles it
if t == NEW_BLOCK:
self.buf.clear()
continue
# Use the constant EOS_TOKEN for comparison if needed (e.g. for logging)
if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
self.buf.append(t - AUDIO_BASE) # Store value relative to base
# else: # Optionally log ignored tokens
# if t != self.eos_token_id: # Don't warn about the EOS token itself
# print(f"Streamer Warning: Ignoring unexpected token {t}")
if len(self.buf) == 7:
audio_bytes = self._decode_block(self.buf)
self.buf.clear()
if audio_bytes:
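                    # generate() runs in a worker thread (asyncio.to_thread), so
                    # marshal the send back onto the event loop thread-safely.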
future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
self.tasks.add(future)
future.add_done_callback(self.tasks.discard)
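                # Record that a block went out so the mask may now allow EOS.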
if self.masker.sent_blocks == 0:
self.masker.sent_blocks = 1
def end(self):
"""Called by generate() when generation finishes."""
if len(self.buf) > 0:
print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
self.buf.clear()
# 5) FastAPI App ------------------------------------------------------
app = FastAPI()
@app.on_event("startup")
async def load_models_startup():
# Keep global references, but EOS_TOKEN is now a constant again
global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
print(f"πŸš€ Starting up on device: {device}")
print("⏳ Lade Modelle …", flush=True)
tok = AutoTokenizer.from_pretrained(REPO)
print("Tokenizer loaded.")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
print(f"SNAC loaded to {device}.")
model_dtype = torch.float32
if device == "cuda":
if torch.cuda.is_bf16_supported():
model_dtype = torch.bfloat16
print("Using bfloat16 for model.")
else:
model_dtype = torch.float16
print("Using float16 for model.")
model = AutoModelForCausalLM.from_pretrained(
REPO,
device_map={"": 0} if device == "cuda" else None,
torch_dtype=model_dtype,
low_cpu_mem_usage=True,
)
print(f"Model loaded to {model.device} with dtype {model.dtype}.")
model.eval()
# --- Print comparison for EOS token IDs but use the constant ---
conf_eos = model.config.eos_token_id
tok_eos = tok.eos_token_id
print(f"Model Config EOS ID: {conf_eos}")
print(f"Tokenizer EOS ID: {tok_eos}")
print(f"Using Constant EOS_TOKEN: {EOS_TOKEN}") # State the used constant
if conf_eos != EOS_TOKEN or tok_eos != EOS_TOKEN:
print(f"⚠️ WARNING: Constant EOS_TOKEN {EOS_TOKEN} differs from model/tokenizer IDs ({conf_eos}/{tok_eos}).")
# --- End EOS comparison ---
# Set pad_token_id if None (use the constant EOS)
if model.config.pad_token_id is None:
print(f"Setting model.config.pad_token_id to Constant EOS token ID ({EOS_TOKEN})")
model.config.pad_token_id = EOS_TOKEN
audio_ids_device = AUDIO_IDS_CPU.to(device)
# Pass the constant EOS_TOKEN to the mask
masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
print("AudioMask initialized.")
# Pass the constant EOS_TOKEN to the stopping criteria
stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
print("StoppingCriteria initialized.")
print("βœ… Modelle geladen und bereit!", flush=True)
@app.get("/")
def hello():
return {"status": "ok", "message": "TTS Service is running"}
# 6) Prompt-building helper ---------------------------------------------
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
"""Builds the input_ids and attention_mask for the model."""
prompt_text = f"{voice}: {text}"
prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
input_ids = torch.cat([
torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),
prompt_ids,
torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)
], dim=1)
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
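# Resulting prompt layout (a sketch of the function above): the model sees
# [START_TOKEN] + tokens of "voice: text" + [NEW_BLOCK]; e.g. voice "Jakob" and
# text "Hallo" become "Jakob: Hallo" framed by the two control tokens, which
# asks the model to open its first audio block.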
# 7) WebSocket endpoint (simplified via the streamer) ------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
# No need for global actual_eos_token_id
await ws.accept()
print("πŸ”Œ Client connected")
streamer = None
main_loop = asyncio.get_running_loop()
try:
req_text = await ws.receive_text()
print(f"Received request: {req_text}")
req = json.loads(req_text)
text = req.get("text", "Hallo Welt, wie geht es dir heute?")
voice = req.get("voice", "Jakob")
if not text:
print("⚠️ Request text is empty.")
await ws.close(code=1003, reason="Text cannot be empty")
return
print(f"Generating audio for: '{text}' with voice '{voice}'")
ids, attn = build_prompt(text, voice)
masker.reset()
# Pass the constant EOS_TOKEN to streamer
streamer = AudioStreamer(ws, snac, masker, main_loop, device, EOS_TOKEN)
print("Starting generation in background thread...")
# Use sampling parameters with anti-repetition measures
await asyncio.to_thread(
model.generate,
input_ids=ids,
attention_mask=attn,
max_new_tokens=2500, # Or adjust as needed
logits_processor=[masker],
stopping_criteria=stopping_criteria,
# --- Sampling Parameters with Anti-Repetition ---
do_sample=True,
temperature=0.6, # Adjust if needed
top_p=0.9, # Adjust if needed
repetition_penalty=1.2, # Increased (experiment!)
no_repeat_ngram_size=4, # Added (experiment!)
# --- End Sampling Parameters ---
use_cache=True,
streamer=streamer,
eos_token_id=EOS_TOKEN # Explicitly pass constant EOS ID
)
print("Generation thread finished.")
except WebSocketDisconnect:
print("πŸ”Œ Client disconnected.")
except json.JSONDecodeError:
print("❌ Invalid JSON received.")
if ws.client_state.name == "CONNECTED":
await ws.close(code=1003, reason="Invalid JSON format")
except Exception as e:
error_details = traceback.format_exc()
print(f"❌ WS‑Error: {e}\n{error_details}", flush=True)
error_payload = json.dumps({"error": str(e)})
try:
if ws.client_state.name == "CONNECTED":
await ws.send_text(error_payload)
except Exception:
pass
if ws.client_state.name == "CONNECTED":
await ws.close(code=1011)
finally:
if streamer:
try:
streamer.end()
except Exception as e_end:
print(f"Error during streamer.end(): {e_end}")
print("Closing connection.")
if ws.client_state.name == "CONNECTED":
try:
await ws.close(code=1000)
except RuntimeError as e_close:
if "Cannot call \"send\"" not in str(e_close) and "Connection is closed" not in str(e_close):
print(f"Runtime error closing websocket: {e_close}")
except Exception as e_close_final:
print(f"Error closing websocket: {e_close_final}")
elif ws.client_state.name != "DISCONNECTED":
print(f"WebSocket final state: {ws.client_state.name}")
print("Connection closed.")
# 8) Dev start --------------------------------------------------------
if __name__ == "__main__":
import uvicorn
print("Starting Uvicorn server...")
uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")