"""
# Orpheus TTS Handler - Explanation & Deployment Guide
This guide explains how to properly deploy the Orpheus TTS model with the custom
handler on Hugging Face Inference Endpoints.
## The Problem
Based on the error messages you're seeing:
1. The connection is working (you get responses)
2. The responses contain text rather than audio data
3. The response format is the standard HF text-generation format: [{"generated_text": "..."}]
This indicates that your endpoint is running the default text generation handler
rather than the custom audio generation handler defined in this file.
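For example, the two response shapes differ roughly as follows (illustrative values
only; the actual text and base64 payload will vary):

```python
# Default text-generation handler (the symptom):
[{"generated_text": "tara: Hello there..."}]

# Custom handler in this file (the goal; see postprocess() for the full payload):
{"audio_b64": "<base64-encoded 24 kHz mono WAV>", "sample_rate": 24000}
```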
## Step 1: Properly package your handler
Create a `handler.py` file with your custom handler code:
"""
import base64
import io
import logging
import os
import wave

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
def __init__(self, path=""):
logger.info("Initializing Orpheus TTS handler")
# Load the Orpheus model and tokenizer
self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16
)
# Move model to GPU if available
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
logger.info(f"Model loaded on {self.device}")
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
logger.info("Tokenizer loaded")
# Load SNAC model for audio decoding
try:
self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
self.snac_model.to(self.device)
logger.info("SNAC model loaded")
except Exception as e:
logger.error(f"Error loading SNAC: {str(e)}")
raise
# Special tokens
self.start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End of text, End of human
self.padding_token = 128263
self.start_audio_token = 128257 # Start of Audio token
self.end_audio_token = 128258 # End of Audio token
self._warmed_up = False
logger.info("Handler initialization complete")
def preprocess(self, data):
"""
Preprocess input data before inference
"""
logger.info(f"Preprocessing data: {type(data)}")
# Handle health check
if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
logger.info("Health check detected")
return {"health_check": True}
# HF Inference API format: 'inputs' is the text, 'parameters' contains the config
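        # Example request body (illustrative values; all parameters are optional):
        #   {"inputs": "Hello world",
        #    "parameters": {"voice": "tara", "temperature": 0.6, "max_new_tokens": 1200}}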
if isinstance(data, dict) and "inputs" in data:
# Standard HF format
text = data["inputs"]
parameters = data.get("parameters", {})
else:
# Direct access (fallback)
text = data
parameters = {}
# Extract parameters from request
voice = parameters.get("voice", "tara")
temperature = float(parameters.get("temperature", 0.6))
top_p = float(parameters.get("top_p", 0.95))
max_new_tokens = int(parameters.get("max_new_tokens", 1200))
repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
# Format prompt with voice
prompt = f"{voice}: {text}"
logger.info(f"Formatted prompt with voice {voice}")
# Tokenize
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
# Add special tokens
modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)
# No need for padding as we're processing a single sequence
input_ids = modified_input_ids.to(self.device)
attention_mask = torch.ones_like(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
"health_check": False
}
def inference(self, inputs):
"""
Run model inference on the preprocessed inputs
"""
# Handle health check
if inputs.get("health_check", False):
return {"status": "ok"}
# Extract parameters
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
temperature = inputs["temperature"]
top_p = inputs["top_p"]
max_new_tokens = inputs["max_new_tokens"]
repetition_penalty = inputs["repetition_penalty"]
logger.info(f"Running inference with max_new_tokens={max_new_tokens}")
# Generate output tokens
with torch.no_grad():
generated_ids = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
num_return_sequences=1,
eos_token_id=self.end_audio_token,
)
logger.info(f"Generation complete, output shape: {generated_ids.shape}")
return generated_ids
def postprocess(self, generated_ids):
"""
Process generated tokens into audio
"""
# Handle health check response
if isinstance(generated_ids, dict) and "status" in generated_ids:
return generated_ids
logger.info("Postprocessing generated tokens")
# Find Start of Audio token
token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)
if len(token_indices[1]) > 0:
last_occurrence_idx = token_indices[1][-1].item()
cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
logger.info(f"Found start audio token at position {last_occurrence_idx}")
else:
cropped_tensor = generated_ids
logger.warning("No start audio token found")
# Remove End of Audio tokens
processed_rows = []
for row in cropped_tensor:
masked_row = row[row != self.end_audio_token]
processed_rows.append(masked_row)
# Prepare audio codes
code_lists = []
for row in processed_rows:
row_length = row.size(0)
# Ensure length is multiple of 7 for SNAC
new_length = (row_length // 7) * 7
trimmed_row = row[:new_length]
            trimmed_row = [t.item() - 128266 for t in trimmed_row]  # Shift token IDs into SNAC code space (128266 is the first audio-code token)
code_lists.append(trimmed_row)
# Generate audio from codes
audio_samples = []
for code_list in code_lists:
logger.info(f"Processing code list of length {len(code_list)}")
if len(code_list) > 0:
audio = self.redistribute_codes(code_list)
audio_samples.append(audio)
else:
logger.warning("Empty code list, no audio to generate")
if not audio_samples:
logger.error("No audio samples generated")
return {"error": "No audio samples generated"}
# Return first (and only) audio sample
audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
        # Convert to base64 for transmission (base64/io/wave are imported at module level)
# Convert float32 array to int16 for WAV format
audio_int16 = (audio_sample * 32767).astype(np.int16)
# Create WAV in memory
with io.BytesIO() as wav_io:
with wave.open(wav_io, 'wb') as wav_file:
wav_file.setnchannels(1) # Mono
wav_file.setsampwidth(2) # 16-bit
wav_file.setframerate(24000) # 24kHz
wav_file.writeframes(audio_int16.tobytes())
wav_data = wav_io.getvalue()
# Encode as base64
audio_b64 = base64.b64encode(wav_data).decode('utf-8')
logger.info(f"Audio encoded as base64, length: {len(audio_b64)}")
return {
"audio_sample": audio_sample,
"audio_b64": audio_b64,
"sample_rate": 24000
}
def redistribute_codes(self, code_list):
"""
Reorganize tokens for SNAC decoding
"""
layer_1 = [] # Coarsest layer
layer_2 = [] # Intermediate layer
layer_3 = [] # Finest layer
num_groups = len(code_list) // 7
for i in range(num_groups):
idx = 7 * i
layer_1.append(code_list[idx])
layer_2.append(code_list[idx + 1] - 4096)
layer_3.append(code_list[idx + 2] - (2 * 4096))
layer_3.append(code_list[idx + 3] - (3 * 4096))
layer_2.append(code_list[idx + 4] - (4 * 4096))
layer_3.append(code_list[idx + 5] - (5 * 4096))
layer_3.append(code_list[idx + 6] - (6 * 4096))
codes = [
torch.tensor(layer_1).unsqueeze(0).to(self.device),
torch.tensor(layer_2).unsqueeze(0).to(self.device),
torch.tensor(layer_3).unsqueeze(0).to(self.device)
]
# Decode audio
audio_hat = self.snac_model.decode(codes)
return audio_hat
def __call__(self, data):
"""
Main entry point for the handler
"""
# Run warmup only once, the first time __call__ is triggered
if not self._warmed_up:
self._warmup()
try:
logger.info(f"Received request: {type(data)}")
# Check if we need to handle the health check route
if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
logger.info("Processing health check request")
return {"status": "ok"}
preprocessed_inputs = self.preprocess(data)
model_outputs = self.inference(preprocessed_inputs)
response = self.postprocess(model_outputs)
return response
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return {"error": str(e)}
    def _warmup(self):
        """
        Run a small generation once to warm up the model before serving requests
        """
        try:
            logger.info("Running warmup generation")
            dummy_prompt = "tara: Hello"
            input_ids = self.tokenizer(dummy_prompt, return_tensors="pt").input_ids.to(self.device)
            _ = self.model.generate(input_ids=input_ids, max_new_tokens=5)
            self._warmed_up = True
            logger.info("Warmup complete")
        except Exception as e:
            logger.error(f"[WARMUP ERROR] {str(e)}")