""" # Orpheus TTS Handler - Explanation & Deployment Guide This guide explains how to properly deploy the Orpheus TTS model with the custom handler on Hugging Face Inference Endpoints. ## The Problem Based on the error messages you're seeing: 1. Connection is working (you get responses) 2. But responses contain text rather than audio data 3. The response format is the standard HF format: [{"generated_text": "..."}] This indicates that your endpoint is running the standard text generation handler rather than the custom audio generation handler you've defined. ## Step 1: Properly package your handler Create a `handler.py` file with your custom handler code: """ import os import torch import numpy as np from transformers import AutoModelForCausalLM, AutoTokenizer from snac import SNAC import logging # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class EndpointHandler: def __init__(self, path=""): logger.info("Initializing Orpheus TTS handler") # Load the Orpheus model and tokenizer self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit" self.model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.bfloat16 ) # Move model to GPU if available self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) logger.info(f"Model loaded on {self.device}") # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) logger.info("Tokenizer loaded") # Load SNAC model for audio decoding try: self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz") self.snac_model.to(self.device) logger.info("SNAC model loaded") except Exception as e: logger.error(f"Error loading SNAC: {str(e)}") raise # Special tokens self.start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End of text, End of human self.padding_token = 128263 self.start_audio_token = 128257 # Start of Audio token self.end_audio_token = 128258 # End of Audio token self._warmed_up = False logger.info("Handler initialization complete") def preprocess(self, data): """ Preprocess input data before inference """ logger.info(f"Preprocessing data: {type(data)}") # Handle health check if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"): logger.info("Health check detected") return {"health_check": True} # HF Inference API format: 'inputs' is the text, 'parameters' contains the config if isinstance(data, dict) and "inputs" in data: # Standard HF format text = data["inputs"] parameters = data.get("parameters", {}) else: # Direct access (fallback) text = data parameters = {} # Extract parameters from request voice = parameters.get("voice", "tara") temperature = float(parameters.get("temperature", 0.6)) top_p = float(parameters.get("top_p", 0.95)) max_new_tokens = int(parameters.get("max_new_tokens", 1200)) repetition_penalty = float(parameters.get("repetition_penalty", 1.1)) # Format prompt with voice prompt = f"{voice}: {text}" logger.info(f"Formatted prompt with voice {voice}") # Tokenize input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids # Add special tokens modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1) # No need for padding as we're processing a single sequence input_ids = modified_input_ids.to(self.device) attention_mask = torch.ones_like(input_ids) return { "input_ids": input_ids, "attention_mask": attention_mask, "temperature": temperature, 
"top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "health_check": False } def inference(self, inputs): """ Run model inference on the preprocessed inputs """ # Handle health check if inputs.get("health_check", False): return {"status": "ok"} # Extract parameters input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] temperature = inputs["temperature"] top_p = inputs["top_p"] max_new_tokens = inputs["max_new_tokens"] repetition_penalty = inputs["repetition_penalty"] logger.info(f"Running inference with max_new_tokens={max_new_tokens}") # Generate output tokens with torch.no_grad(): generated_ids = self.model.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, num_return_sequences=1, eos_token_id=self.end_audio_token, ) logger.info(f"Generation complete, output shape: {generated_ids.shape}") return generated_ids def postprocess(self, generated_ids): """ Process generated tokens into audio """ # Handle health check response if isinstance(generated_ids, dict) and "status" in generated_ids: return generated_ids logger.info("Postprocessing generated tokens") # Find Start of Audio token token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True) if len(token_indices[1]) > 0: last_occurrence_idx = token_indices[1][-1].item() cropped_tensor = generated_ids[:, last_occurrence_idx+1:] logger.info(f"Found start audio token at position {last_occurrence_idx}") else: cropped_tensor = generated_ids logger.warning("No start audio token found") # Remove End of Audio tokens processed_rows = [] for row in cropped_tensor: masked_row = row[row != self.end_audio_token] processed_rows.append(masked_row) # Prepare audio codes code_lists = [] for row in processed_rows: row_length = row.size(0) # Ensure length is multiple of 7 for SNAC new_length = (row_length // 7) * 7 trimmed_row = row[:new_length] trimmed_row = [t.item() - 128266 for t in trimmed_row] # Adjust token values code_lists.append(trimmed_row) # Generate audio from codes audio_samples = [] for code_list in code_lists: logger.info(f"Processing code list of length {len(code_list)}") if len(code_list) > 0: audio = self.redistribute_codes(code_list) audio_samples.append(audio) else: logger.warning("Empty code list, no audio to generate") if not audio_samples: logger.error("No audio samples generated") return {"error": "No audio samples generated"} # Return first (and only) audio sample audio_sample = audio_samples[0].detach().squeeze().cpu().numpy() # Convert to base64 for transmission import base64 import io import wave # Convert float32 array to int16 for WAV format audio_int16 = (audio_sample * 32767).astype(np.int16) # Create WAV in memory with io.BytesIO() as wav_io: with wave.open(wav_io, 'wb') as wav_file: wav_file.setnchannels(1) # Mono wav_file.setsampwidth(2) # 16-bit wav_file.setframerate(24000) # 24kHz wav_file.writeframes(audio_int16.tobytes()) wav_data = wav_io.getvalue() # Encode as base64 audio_b64 = base64.b64encode(wav_data).decode('utf-8') logger.info(f"Audio encoded as base64, length: {len(audio_b64)}") return { "audio_sample": audio_sample, "audio_b64": audio_b64, "sample_rate": 24000 } def redistribute_codes(self, code_list): """ Reorganize tokens for SNAC decoding """ layer_1 = [] # Coarsest layer layer_2 = [] # Intermediate layer layer_3 = [] # Finest layer num_groups = len(code_list) // 7 for i in range(num_groups): idx 
    def redistribute_codes(self, code_list):
        """
        Reorganize tokens for SNAC decoding
        """
        layer_1 = []  # Coarsest layer
        layer_2 = []  # Intermediate layer
        layer_3 = []  # Finest layer

        num_groups = len(code_list) // 7
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx])
            layer_2.append(code_list[idx + 1] - 4096)
            layer_3.append(code_list[idx + 2] - (2 * 4096))
            layer_3.append(code_list[idx + 3] - (3 * 4096))
            layer_2.append(code_list[idx + 4] - (4 * 4096))
            layer_3.append(code_list[idx + 5] - (5 * 4096))
            layer_3.append(code_list[idx + 6] - (6 * 4096))

        codes = [
            torch.tensor(layer_1).unsqueeze(0).to(self.device),
            torch.tensor(layer_2).unsqueeze(0).to(self.device),
            torch.tensor(layer_3).unsqueeze(0).to(self.device)
        ]

        # Decode audio
        audio_hat = self.snac_model.decode(codes)
        return audio_hat

    def __call__(self, data):
        """
        Main entry point for the handler
        """
        # Run warmup only once, the first time __call__ is triggered
        if not self._warmed_up:
            self._warmup()

        try:
            logger.info(f"Received request: {type(data)}")

            # Check if we need to handle the health check route
            if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
                logger.info("Processing health check request")
                return {"status": "ok"}

            preprocessed_inputs = self.preprocess(data)
            model_outputs = self.inference(preprocessed_inputs)
            response = self.postprocess(model_outputs)

            return response
        except Exception as e:
            logger.error(f"Error processing request: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            return {"error": str(e)}

    def _warmup(self):
        # Run one tiny generation so the first real request doesn't pay the full
        # model warm-up cost
        try:
            dummy_prompt = "tara: Hello"
            input_ids = self.tokenizer(dummy_prompt, return_tensors="pt").input_ids.to(self.device)
            _ = self.model.generate(input_ids=input_ids, max_new_tokens=5)
            self._warmed_up = True
        except Exception as e:
            logger.error(f"[WARMUP ERROR] {str(e)}")
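

# ---------------------------------------------------------------------------
# Optional local smoke test (not run on the endpoint). A minimal sketch of how
# a request payload flows through the handler and how the base64 WAV in the
# response can be written to disk. It assumes the model weights fit on your
# local machine; "sample.wav" is just an example output path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import base64

    handler = EndpointHandler()
    result = handler({
        "inputs": "Hello, this is a quick test of the Orpheus handler.",
        "parameters": {"voice": "tara", "max_new_tokens": 600}
    })

    if isinstance(result, dict) and "audio_b64" in result:
        # Decode the base64 WAV returned by postprocess() and save it
        with open("sample.wav", "wb") as f:
            f.write(base64.b64decode(result["audio_b64"]))
        print(f"Wrote sample.wav ({result['sample_rate']} Hz)")
    else:
        print(f"Handler returned: {result}")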