"""
# Orpheus TTS Handler - Explanation & Deployment Guide

This guide explains how to properly deploy the Orpheus TTS model with the custom
handler on Hugging Face Inference Endpoints.

## The Problem

Based on the error messages you're seeing:

1. The connection is working (you get responses).
2. But the responses contain text rather than audio data.
3. The response format is the standard HF text-generation format:
   `[{"generated_text": "..."}]`.

This indicates that your endpoint is running the default text-generation handler
rather than the custom audio-generation handler you've defined.

## Step 1: Properly package your handler

Create a `handler.py` file (this file) in the root of your model repository with
your custom handler code. Inference Endpoints only picks up a custom handler when
the file is named exactly `handler.py` and defines an `EndpointHandler` class; add
a `requirements.txt` alongside it for dependencies missing from the default image
(e.g. `snac`).
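
## Step 2: Verify from the client (sketch)

Once the custom handler is active, the endpoint returns base64-encoded WAV audio
instead of `generated_text`. The snippet below is a minimal client sketch, not part
of the handler itself; the endpoint URL and token are placeholders you must replace
with your own values.

```python
import base64
import requests

API_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer YOUR_HF_TOKEN"}            # placeholder

payload = {
    "inputs": "Hello, this is a test of the Orpheus TTS endpoint.",
    "parameters": {"voice": "tara", "temperature": 0.6},
}
result = requests.post(API_URL, headers=HEADERS, json=payload).json()

# On success the handler returns {"audio_b64": ..., "sample_rate": 24000}
with open("output.wav", "wb") as f:
    f.write(base64.b64decode(result["audio_b64"]))
```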
"""

import base64
import io
import logging
import traceback
import wave

import numpy as np
import torch
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path=""):
        logger.info("Initializing Orpheus TTS handler")

        # NOTE: the model is loaded from the Hub by name; the `path` argument
        # (the local repository directory) is not used here.
        self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
        )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logger.info(f"Model loaded on {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        logger.info("Tokenizer loaded")

        # SNAC decodes the generated audio codes into a 24 kHz waveform
        try:
            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
            self.snac_model.to(self.device)
            logger.info("SNAC model loaded")
        except Exception as e:
            logger.error(f"Error loading SNAC: {str(e)}")
            raise

        # Special token IDs that Orpheus uses to frame the prompt and to
        # delimit the audio-code span in the generated sequence
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
        self.padding_token = 128263
        self.start_audio_token = 128257
        self.end_audio_token = 128258

        self._warmed_up = False

        logger.info("Handler initialization complete")

    def preprocess(self, data):
        """
        Preprocess input data before inference.
        """
        logger.info(f"Preprocessing data: {type(data)}")

        # Health checks short-circuit the pipeline
        if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
            logger.info("Health check detected")
            return {"health_check": True}

        # Accept both the standard {"inputs": ..., "parameters": ...} payload
        # and a bare string
        if isinstance(data, dict) and "inputs" in data:
            text = data["inputs"]
            parameters = data.get("parameters", {})
        else:
            text = data
            parameters = {}

        voice = parameters.get("voice", "tara")
        temperature = float(parameters.get("temperature", 0.6))
        top_p = float(parameters.get("top_p", 0.95))
        max_new_tokens = int(parameters.get("max_new_tokens", 1200))
        repetition_penalty = float(parameters.get("repetition_penalty", 1.1))

        # Orpheus selects the speaker via a "voice: text" prompt prefix
        prompt = f"{voice}: {text}"
        logger.info(f"Formatted prompt with voice {voice}")

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

        # Wrap the prompt in the special start/end tokens the model expects
        modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)

        input_ids = modified_input_ids.to(self.device)
        attention_mask = torch.ones_like(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "health_check": False,
        }

    def inference(self, inputs):
        """
        Run model inference on the preprocessed inputs.
        """
        if inputs.get("health_check", False):
            return {"status": "ok"}

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        temperature = inputs["temperature"]
        top_p = inputs["top_p"]
        max_new_tokens = inputs["max_new_tokens"]
        repetition_penalty = inputs["repetition_penalty"]

        logger.info(f"Running inference with max_new_tokens={max_new_tokens}")

        # Generation stops when the model emits the end-of-audio token
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.end_audio_token,
            )

        logger.info(f"Generation complete, output shape: {generated_ids.shape}")
        return generated_ids

    def postprocess(self, generated_ids):
        """
        Process generated tokens into audio.
        """
        if isinstance(generated_ids, dict) and "status" in generated_ids:
            return generated_ids

        logger.info("Postprocessing generated tokens")

        # Keep only the tokens after the *last* start-of-audio marker
        token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)

        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
            logger.info(f"Found start audio token at position {last_occurrence_idx}")
        else:
            cropped_tensor = generated_ids
            logger.warning("No start audio token found")

        # Strip the end-of-audio marker from each row
        processed_rows = []
        for row in cropped_tensor:
            masked_row = row[row != self.end_audio_token]
            processed_rows.append(masked_row)

        # SNAC codes come in frames of 7 tokens; trim any partial frame and
        # shift token IDs down into the audio-code space
        code_lists = []
        for row in processed_rows:
            row_length = row.size(0)
            new_length = (row_length // 7) * 7
            trimmed_row = row[:new_length]
            trimmed_row = [t.item() - 128266 for t in trimmed_row]
            code_lists.append(trimmed_row)

        audio_samples = []
        for code_list in code_lists:
            logger.info(f"Processing code list of length {len(code_list)}")
            if len(code_list) > 0:
                audio = self.redistribute_codes(code_list)
                audio_samples.append(audio)
            else:
                logger.warning("Empty code list, no audio to generate")

        if not audio_samples:
            logger.error("No audio samples generated")
            return {"error": "No audio samples generated"}

        audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()

        # Convert float samples in [-1, 1] to 16-bit PCM and wrap in a WAV container
        audio_int16 = (audio_sample * 32767).astype(np.int16)

        with io.BytesIO() as wav_io:
            with wave.open(wav_io, 'wb') as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)
                wav_file.setframerate(24000)
                wav_file.writeframes(audio_int16.tobytes())
            wav_data = wav_io.getvalue()

        audio_b64 = base64.b64encode(wav_data).decode('utf-8')
        logger.info(f"Audio encoded as base64, length: {len(audio_b64)}")

        # Return only JSON-serializable fields; including the raw numpy array
        # would break the endpoint's JSON response
        return {
            "audio_b64": audio_b64,
            "sample_rate": 24000,
        }

    def redistribute_codes(self, code_list):
        """
        Reorganize a flat list of audio codes into SNAC's three codebook layers.
        """
        layer_1 = []
        layer_2 = []
        layer_3 = []

        num_groups = len(code_list) // 7
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx])
            layer_2.append(code_list[idx + 1] - 4096)
            layer_3.append(code_list[idx + 2] - (2 * 4096))
            layer_3.append(code_list[idx + 3] - (3 * 4096))
            layer_2.append(code_list[idx + 4] - (4 * 4096))
            layer_3.append(code_list[idx + 5] - (5 * 4096))
            layer_3.append(code_list[idx + 6] - (6 * 4096))

        codes = [
            torch.tensor(layer_1).unsqueeze(0).to(self.device),
            torch.tensor(layer_2).unsqueeze(0).to(self.device),
            torch.tensor(layer_3).unsqueeze(0).to(self.device),
        ]

        audio_hat = self.snac_model.decode(codes)
        return audio_hat

    def __call__(self, data):
        """
        Main entry point for the handler.
        """
        # Warm the model up lazily on the first request
        if not self._warmed_up:
            self._warmup()

        try:
            logger.info(f"Received request: {type(data)}")

            if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
                logger.info("Processing health check request")
                return {"status": "ok"}

            preprocessed_inputs = self.preprocess(data)
            model_outputs = self.inference(preprocessed_inputs)
            response = self.postprocess(model_outputs)
            return response
        except Exception as e:
            logger.error(f"Error processing request: {str(e)}")
            logger.error(traceback.format_exc())
            return {"error": str(e)}

    def _warmup(self):
        # Run a tiny generation once so the first real request does not pay
        # the full CUDA initialization cost
        try:
            dummy_prompt = "tara: Hello"
            input_ids = self.tokenizer(dummy_prompt, return_tensors="pt").input_ids.to(self.device)
            _ = self.model.generate(input_ids=input_ids, max_new_tokens=5)
            self._warmed_up = True
        except Exception as e:
            logger.error(f"Warmup error: {str(e)}")