# Hugging Face Space status at capture time: Paused
# app.py ───────────────────────────────────────────────────────────────
import os | |
import json | |
import torch | |
import asyncio | |
import traceback # Import traceback for better error logging | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
from huggingface_hub import login | |
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList | |
# Import BaseStreamer for the interface | |
from transformers.generation.streamers import BaseStreamer | |
from snac import SNAC # Ensure you have 'pip install snac' | |
# --- Module-level state -------------------------------------------------
# All handles start out as None; per the original note they are filled in
# by `load_models` (defined outside this section).
tok = model = snac = masker = stopping_criteria = None

# Prefer the GPU when CUDA is usable, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# 0) Hugging Face Hub login --------------------------------------------
# Authenticate only when a non-empty HF_TOKEN environment variable exists.
# NOTE: login() is not wrapped in try/except, so any exception it raises
# propagates and aborts module import.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("π Logging in to Hugging Face Hub...")
    login(HF_TOKEN)
# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment if needed to work around a PyTorch 2.2 bug
# 1) Constants ----------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"

# Special token IDs of the model's vocabulary.
START_TOKEN = 128259
NEW_BLOCK = 128257   # marks the start of audio generation
EOS_TOKEN = 128258   # end-of-speech token
AUDIO_BASE = 128266  # first audio token ID

# Audio tokens come in frames of 7 slots (later mapped onto the 3 SNAC
# codebooks), with 4096 codes per slot -> 28672 audio token IDs in total.
AUDIO_SPAN = 7 * 4096

# Every audio token ID, built on the CPU; load_models is expected to move
# it to the target device.
AUDIO_IDS_CPU = AUDIO_BASE + torch.arange(AUDIO_SPAN)
# 2) Logit mask ---------------------------------------------------------
class AudioMask(LogitsProcessor):
    """
    Logits processor that restricts which token IDs can be produced.

    While ``sent_blocks`` is 0, only the NEW_BLOCK token and the given audio
    token IDs keep their scores. Once ``sent_blocks`` is positive, the EOS
    token is allowed as well. Every other ID is pushed to -inf.

    The counter is not advanced here. Its owner updates it (AudioStreamer
    keeps a reference for that purpose), and ``reset()`` clears it.
    """

    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        target = audio_ids.device
        new_block = torch.tensor([new_block_token_id], device=target)
        # Allow-list used before any audio block has been sent.
        self.allow_initial = torch.cat([new_block, audio_ids], dim=0)
        self.eos = torch.tensor([eos_token_id], device=target)
        # Same allow-list plus EOS, built once rather than on every step.
        self.allow_with_eos = torch.cat([self.allow_initial, self.eos], dim=0)
        # Number of audio blocks reported as sent for the current request.
        self.sent_blocks = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.sent_blocks > 0:
            permitted = self.allow_with_eos
        else:
            permitted = self.allow_initial
        # Additive mask: 0 keeps a score unchanged, -inf removes the token.
        additive = torch.full_like(scores, float("-inf"))
        additive[:, permitted] = 0
        return scores + additive

    def reset(self):
        """Zero the sent-block counter before a new generation request."""
        self.sent_blocks = 0
# 3) StoppingCriteria for EOS ------------------------------------------
# Explicit EOS stopping criterion handed to generate() next to the streamer.
class EosStoppingCriteria(StoppingCriteria):
    """Mark a sequence as finished once its most recent token is the EOS token.

    ``__call__`` returns one boolean per batch row, which is the form the
    transformers stopping-criteria machinery consumes. For a batch of one,
    the 1-element result is truthy exactly when the last token is EOS.
    """

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """Return a ``(batch_size,)`` bool tensor that is True where the last token is EOS.

        Args:
            input_ids: Token IDs so far, indexed as ``[batch, position]``.
            scores: Unused; part of the StoppingCriteria call signature.
        """
        batch_size = input_ids.shape[0]
        if input_ids.shape[1] == 0:
            # Empty sequences cannot end in EOS.
            return torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
        # Compare per row. Putting this tensor directly into an `if` (as a
        # scalar) would raise "ambiguous truth value" for batch sizes > 1.
        is_done = input_ids[:, -1] == self.eos_token_id
        if bool(is_done.any()):
            print("StoppingCriteria: EOS detected.")
        return is_done
# 4) Custom AudioStreamer ----------------------------------------------
class AudioStreamer(BaseStreamer): | |
""" | |
Custom streamer to process audio tokens, decode them using SNAC, | |
and send audio bytes over a WebSocket. | |
""" | |
# Added target_device parameter | |
def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
    """Bind the streamer to one WebSocket, a SNAC decoder and the logit mask.

    Args:
        ws: Connection the decoded audio bytes are sent over.
        snac_decoder: SNAC model used to turn code blocks into audio.
        audio_mask: The AudioMask whose ``sent_blocks`` counter this
            streamer updates.
        loop: Event loop of the main thread, used with
            ``run_coroutine_threadsafe`` from the generation worker thread.
        target_device: Device on which code tensors are created for SNAC.
    """
    # Collaborators
    self.ws, self.snac, self.masker = ws, snac_decoder, audio_mask
    self.loop, self.device = loop, target_device
    # Per-request mutable state
    self.buf: list[int] = []  # pending audio code values (AUDIO_BASE already subtracted)
    self.tasks = set()        # send tasks that have not finished yet
def _decode_block(self, block7: list[int]) -> bytes:
    """Decode one frame of 7 audio code values into 16-bit PCM bytes.

    Args:
        block7: The 7 code values of one frame, with AUDIO_BASE already
            subtracted.

    Returns:
        int16 PCM samples as raw bytes (native byte order), or ``b""`` when
        the block is malformed or conversion/decoding fails (a message is
        printed in that case).

    The 7 values are spread over the three SNAC code levels as
    l1 = [0], l2 = [1, 4], l3 = [2, 3, 5, 6], i.e. 1 / 2 / 4 codes per level.
    This layout is model-specific; if the audio sounds distorted, check it
    first.

    NOTE(review): the reference Orpheus decoder also removes a per-slot
    offset of ``i * 4096`` from slot ``i``. Confirm that the caller does this
    before buffering. Otherwise codes >= 4096 arrive here and are rejected by
    the range check below.
    """
    codebook_size = 4096  # codes per SNAC codebook (matches AUDIO_SPAN = 7 * 4096)

    if len(block7) != 7:
        print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
        return b""  # Return empty bytes if block is incomplete
    # Reject out-of-range codes up front. Inside SNAC they would fail in the
    # codebook lookup, and on CUDA that is a device-side assert that leaves
    # the CUDA context unusable.
    if any(not 0 <= code < codebook_size for code in block7):
        print(f"Streamer Warning: _decode_block received codes outside [0, {codebook_size}): {block7}. Skipping.")
        return b""

    # Split the frame across the three SNAC levels (1, 2 and 4 codes).
    levels = (
        [block7[0]],
        [block7[1], block7[4]],
        [block7[2], block7[3], block7[5], block7[6]],
    )
    try:
        # One (1, n) LongTensor per level on the decoder's device.
        codes = [
            torch.tensor(level, dtype=torch.long, device=self.device).unsqueeze(0)
            for level in levels
        ]
    except Exception as e:
        print(f"Streamer Error: Failed converting lists to tensors. Error: {e}")
        return b""

    try:
        with torch.no_grad():
            audio = self.snac.decode(codes)[0]  # [0] takes the first (only) batch item
    except Exception as e:
        print(f"Streamer Error: snac.decode failed. Input shapes: {[c.shape for c in codes]}. Error: {e}")
        return b""

    # Clamp to [-1, 1] before scaling. NumPy's float -> int16 cast of
    # out-of-range values is undefined and typically wraps around, which
    # shows up as loud clicks.
    audio_np = audio.squeeze().detach().clamp(-1.0, 1.0).cpu().numpy()
    return (audio_np * 32767).astype("int16").tobytes()
async def _send_audio_bytes(self, data: bytes):
    """Send one chunk of audio bytes over the WebSocket.

    Empty chunks are skipped. A disconnect, or any other ``Exception`` raised
    by the send, is printed and not propagated to the awaiting code.
    """
    if data:
        try:
            await self.ws.send_bytes(data)
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as err:
            print(f"Streamer: Error sending bytes: {err}")
def put(self, value: torch.LongTensor): | |
""" | |
Receives new token IDs (Tensor) from generate() (runs in worker thread). | |
Processes tokens, decodes full blocks, and schedules sending via run_coroutine_threadsafe. | |
""" | |
# Ensure value is on CPU and flatten to a list of ints | |
if value.numel() == 0: | |
return | |
# Handle potential shape issues, ensure it's iterable | |
try: | |
new_token_ids = value.view(-1).tolist() | |
except Exception as e: | |
print(f"Streamer Error: Could not process incoming tensor: {value}, Error: {e}") | |
return | |
for t in new_token_ids: | |
if t == EOS_TOKEN: | |
# print("Streamer: EOS token encountered.") | |
# EOS is handled by StoppingCriteria, no action needed here except maybe logging. | |
break # Stop processing this batch if EOS is found | |
if t == NEW_BLOCK: |