Update handler.py
handler.py CHANGED (+79 -9)
```diff
@@ -1,11 +1,39 @@
+"""
+# Orpheus TTS Handler - Explanation & Deployment Guide
+
+This guide explains how to properly deploy the Orpheus TTS model with the custom
+handler on Hugging Face Inference Endpoints.
+
+## The Problem
+
+Based on the error messages you're seeing:
+1. Connection is working (you get responses)
+2. But responses contain text rather than audio data
+3. The response format is the standard HF format: [{"generated_text": "..."}]
+
+This indicates that your endpoint is running the standard text generation handler
+rather than the custom audio generation handler you've defined.
+
+## Step 1: Properly package your handler
+
+Create a `handler.py` file with your custom handler code:
+"""
+
+# Code from your original handler, but with some fixes
 import os
 import torch
 import numpy as np
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from snac import SNAC
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 class EndpointHandler:
     def __init__(self, path=""):
+        logger.info("Initializing Orpheus TTS handler")
         # Load the Orpheus model and tokenizer
         self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
         self.model = AutoModelForCausalLM.from_pretrained(
@@ -16,13 +44,20 @@ class EndpointHandler:
         # Move model to GPU if available
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
+        logger.info(f"Model loaded on {self.device}")
 
         # Load tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        logger.info("Tokenizer loaded")
 
         # Load SNAC model for audio decoding
-        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-        self.snac_model.to(self.device)
+        try:
+            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+            self.snac_model.to(self.device)
+            logger.info("SNAC model loaded")
+        except Exception as e:
+            logger.error(f"Error loading SNAC: {str(e)}")
+            raise
 
         # Special tokens
         self.start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
@@ -31,14 +66,20 @@
         self.start_audio_token = 128257 # Start of Audio token
         self.end_audio_token = 128258 # End of Audio token
 
-
+        logger.info("Handler initialization complete")
 
     def preprocess(self, data):
         """
         Preprocess input data before inference
         """
+        logger.info(f"Preprocessing data: {type(data)}")
+
+        # Handle health check
+        if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
+            logger.info("Health check detected")
+            return {"health_check": True}
+
         # HF Inference API format: 'inputs' is the text, 'parameters' contains the config
-        # Handle both direct access and standardized HF format
         if isinstance(data, dict) and "inputs" in data:
             # Standard HF format
             text = data["inputs"]
@@ -57,6 +98,7 @@
 
         # Format prompt with voice
         prompt = f"{voice}: {text}"
+        logger.info(f"Formatted prompt with voice {voice}")
 
         # Tokenize
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -74,13 +116,18 @@
             "temperature": temperature,
             "top_p": top_p,
             "max_new_tokens": max_new_tokens,
-            "repetition_penalty": repetition_penalty
+            "repetition_penalty": repetition_penalty,
+            "health_check": False
         }
 
     def inference(self, inputs):
         """
         Run model inference on the preprocessed inputs
         """
+        # Handle health check
+        if inputs.get("health_check", False):
+            return {"status": "ok"}
+
         # Extract parameters
         input_ids = inputs["input_ids"]
         attention_mask = inputs["attention_mask"]
@@ -89,6 +136,8 @@
         max_new_tokens = inputs["max_new_tokens"]
         repetition_penalty = inputs["repetition_penalty"]
 
+        logger.info(f"Running inference with max_new_tokens={max_new_tokens}")
+
         # Generate output tokens
         with torch.no_grad():
             generated_ids = self.model.generate(
@@ -103,20 +152,29 @@
                 eos_token_id=self.end_audio_token,
             )
 
+        logger.info(f"Generation complete, output shape: {generated_ids.shape}")
         return generated_ids
 
     def postprocess(self, generated_ids):
         """
         Process generated tokens into audio
         """
+        # Handle health check response
+        if isinstance(generated_ids, dict) and "status" in generated_ids:
+            return generated_ids
+
+        logger.info("Postprocessing generated tokens")
+
         # Find Start of Audio token
         token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)
 
         if len(token_indices[1]) > 0:
             last_occurrence_idx = token_indices[1][-1].item()
             cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+            logger.info(f"Found start audio token at position {last_occurrence_idx}")
         else:
             cropped_tensor = generated_ids
+            logger.warning("No start audio token found")
 
         # Remove End of Audio tokens
         processed_rows = []
@@ -137,8 +195,16 @@
         # Generate audio from codes
         audio_samples = []
         for code_list in code_lists:
-            audio = self.redistribute_codes(code_list)
-            audio_samples.append(audio)
+            logger.info(f"Processing code list of length {len(code_list)}")
+            if len(code_list) > 0:
+                audio = self.redistribute_codes(code_list)
+                audio_samples.append(audio)
+            else:
+                logger.warning("Empty code list, no audio to generate")
+
+        if not audio_samples:
+            logger.error("No audio samples generated")
+            return {"error": "No audio samples generated"}
 
         # Return first (and only) audio sample
         audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
@@ -162,6 +228,7 @@
 
         # Encode as base64
         audio_b64 = base64.b64encode(wav_data).decode('utf-8')
+        logger.info(f"Audio encoded as base64, length: {len(audio_b64)}")
 
         return {
             "audio_b64": audio_b64,
@@ -205,7 +272,8 @@
         logger.info(f"Received request: {type(data)}")
 
         # Check if we need to handle the health check route
-        if data == "ping" or data
+        if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
+            logger.info("Processing health check request")
             return {"status": "ok"}
 
         preprocessed_inputs = self.preprocess(data)
@@ -216,4 +284,6 @@
             logger.error(f"Error processing request: {str(e)}")
             import traceback
             logger.error(traceback.format_exc())
-        return {"error": str(e)}
+            return {"error": str(e)}
+
+"
```
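On the packaging step the embedded guide describes: Hugging Face Inference Endpoints only loads a custom handler when `handler.py` (defining an `EndpointHandler` class) sits at the root of the model repository, with any extra dependencies declared in a `requirements.txt` beside it. A minimal sketch of the assumed layout follows; neither listing is part of this commit:

```text
# Assumed repository layout for a custom handler:
#   handler.py          <- this file, at the repo root
#   requirements.txt    <- extra dependencies installed into the endpoint image
#   (model weights, config, tokenizer files, ...)
#
# requirements.txt would need at least the non-standard import used above:
snac
```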
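Once deployed, a client can exercise both the health check and the audio path. The sketch below is not part of the commit: the URL, token, voice name, and sampling values are placeholders, and the `parameters` keys simply mirror what `preprocess()` reads above; only the `audio_b64` field of the response is guaranteed by this handler.

```python
import base64
import requests

# Placeholder endpoint URL and access token -- substitute your own values.
API_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"
HEADERS = {"Authorization": "Bearer hf_xxx", "Content-Type": "application/json"}

# Health check: the handler short-circuits "ping" and returns {"status": "ok"}.
print(requests.post(API_URL, headers=HEADERS, json={"inputs": "ping"}).json())

payload = {
    "inputs": "Hello, this is a test of the Orpheus TTS endpoint.",
    "parameters": {
        # "voice" is prepended to the prompt as f"{voice}: {text}".
        # These sampling values are illustrative, not the handler's defaults.
        "voice": "tara",
        "temperature": 0.6,
        "top_p": 0.95,
        "max_new_tokens": 1200,
        "repetition_penalty": 1.1,
    },
}
result = requests.post(API_URL, headers=HEADERS, json=payload).json()

if isinstance(result, dict) and "audio_b64" in result:
    # postprocess() returns base64-encoded WAV bytes under "audio_b64".
    with open("output.wav", "wb") as f:
        f.write(base64.b64decode(result["audio_b64"]))
    print("Wrote output.wav")
else:
    # A [{"generated_text": ...}] response means the default text-generation
    # pipeline handled the request and the custom handler was never loaded.
    print("Unexpected response:", result)
```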