hypaai committed 47fb761 (verified) · 1 parent: 21ba720

Update handler.py

Files changed (1): handler.py (+79 −9)
@@ -1,11 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import torch
3
  import numpy as np
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from snac import SNAC
 
 
 
 
 
6
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
 
9
  # Load the Orpheus model and tokenizer
10
  self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
11
  self.model = AutoModelForCausalLM.from_pretrained(
@@ -16,13 +44,20 @@ class EndpointHandler:
16
  # Move model to GPU if available
17
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
  self.model.to(self.device)
 
19
 
20
  # Load tokenizer
21
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
22
 
23
  # Load SNAC model for audio decoding
24
- self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
25
- self.snac_model.to(self.device)
 
 
 
 
 
26
 
27
  # Special tokens
28
  self.start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
@@ -31,14 +66,20 @@ class EndpointHandler:
31
  self.start_audio_token = 128257 # Start of Audio token
32
  self.end_audio_token = 128258 # End of Audio token
33
 
34
- print(f"Model loaded on {self.device}")
35
 
36
  def preprocess(self, data):
37
  """
38
  Preprocess input data before inference
39
  """
 
 
 
 
 
 
 
40
  # HF Inference API format: 'inputs' is the text, 'parameters' contains the config
41
- # Handle both direct access and standardized HF format
42
  if isinstance(data, dict) and "inputs" in data:
43
  # Standard HF format
44
  text = data["inputs"]
@@ -57,6 +98,7 @@ class EndpointHandler:
57
 
58
  # Format prompt with voice
59
  prompt = f"{voice}: {text}"
 
60
 
61
  # Tokenize
62
  input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -74,13 +116,18 @@ class EndpointHandler:
74
  "temperature": temperature,
75
  "top_p": top_p,
76
  "max_new_tokens": max_new_tokens,
77
- "repetition_penalty": repetition_penalty
 
78
  }
79
 
80
  def inference(self, inputs):
81
  """
82
  Run model inference on the preprocessed inputs
83
  """
 
 
 
 
84
  # Extract parameters
85
  input_ids = inputs["input_ids"]
86
  attention_mask = inputs["attention_mask"]
@@ -89,6 +136,8 @@ class EndpointHandler:
89
  max_new_tokens = inputs["max_new_tokens"]
90
  repetition_penalty = inputs["repetition_penalty"]
91
 
 
 
92
  # Generate output tokens
93
  with torch.no_grad():
94
  generated_ids = self.model.generate(
@@ -103,20 +152,29 @@ class EndpointHandler:
103
  eos_token_id=self.end_audio_token,
104
  )
105
 
 
106
  return generated_ids
107
 
108
  def postprocess(self, generated_ids):
109
  """
110
  Process generated tokens into audio
111
  """
 
 
 
 
 
 
112
  # Find Start of Audio token
113
  token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)
114
 
115
  if len(token_indices[1]) > 0:
116
  last_occurrence_idx = token_indices[1][-1].item()
117
  cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
 
118
  else:
119
  cropped_tensor = generated_ids
 
120
 
121
  # Remove End of Audio tokens
122
  processed_rows = []
@@ -137,8 +195,16 @@ class EndpointHandler:
137
  # Generate audio from codes
138
  audio_samples = []
139
  for code_list in code_lists:
140
- audio = self.redistribute_codes(code_list)
141
- audio_samples.append(audio)
 
 
 
 
 
 
 
 
142
 
143
  # Return first (and only) audio sample
144
  audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
@@ -162,6 +228,7 @@ class EndpointHandler:
162
 
163
  # Encode as base64
164
  audio_b64 = base64.b64encode(wav_data).decode('utf-8')
 
165
 
166
  return {
167
  "audio_b64": audio_b64,
@@ -205,7 +272,8 @@ class EndpointHandler:
205
  logger.info(f"Received request: {type(data)}")
206
 
207
  # Check if we need to handle the health check route
208
- if data == "ping" or data == {"inputs": "ping"}:
 
209
  return {"status": "ok"}
210
 
211
  preprocessed_inputs = self.preprocess(data)
@@ -216,4 +284,6 @@ class EndpointHandler:
216
  logger.error(f"Error processing request: {str(e)}")
217
  import traceback
218
  logger.error(traceback.format_exc())
219
- return {"error": str(e)}
 
 
 
1
+ """
2
+ # Orpheus TTS Handler - Explanation & Deployment Guide
3
+
4
+ This guide explains how to properly deploy the Orpheus TTS model with the custom
5
+ handler on Hugging Face Inference Endpoints.
6
+
7
+ ## The Problem
8
+
9
+ Based on the error messages you're seeing:
10
+ 1. Connection is working (you get responses)
11
+ 2. But responses contain text rather than audio data
12
+ 3. The response format is the standard HF format: [{"generated_text": "..."}]
13
+
14
+ This indicates that your endpoint is running the standard text generation handler
15
+ rather than the custom audio generation handler you've defined.
16
+
17
+ ## Step 1: Properly package your handler
18
+
19
+ Create a `handler.py` file with your custom handler code:
20
+ """
21
+
22
+ # Code from your original handler, but with some fixes
23
  import os
24
  import torch
25
  import numpy as np
26
  from transformers import AutoModelForCausalLM, AutoTokenizer
27
  from snac import SNAC
28
+ import logging
29
+
30
+ # Set up logging
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
 
34
  class EndpointHandler:
35
  def __init__(self, path=""):
36
+ logger.info("Initializing Orpheus TTS handler")
37
  # Load the Orpheus model and tokenizer
38
  self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
39
  self.model = AutoModelForCausalLM.from_pretrained(
 
44
  # Move model to GPU if available
45
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
46
  self.model.to(self.device)
47
+ logger.info(f"Model loaded on {self.device}")
48
 
49
  # Load tokenizer
50
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
51
+ logger.info("Tokenizer loaded")
52
 
53
  # Load SNAC model for audio decoding
54
+ try:
55
+ self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
56
+ self.snac_model.to(self.device)
57
+ logger.info("SNAC model loaded")
58
+ except Exception as e:
59
+ logger.error(f"Error loading SNAC: {str(e)}")
60
+ raise
61
 
62
  # Special tokens
63
  self.start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
 
66
  self.start_audio_token = 128257 # Start of Audio token
67
  self.end_audio_token = 128258 # End of Audio token
68
 
69
+ logger.info("Handler initialization complete")
70
 
71
  def preprocess(self, data):
72
  """
73
  Preprocess input data before inference
74
  """
75
+ logger.info(f"Preprocessing data: {type(data)}")
76
+
77
+ # Handle health check
78
+ if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
79
+ logger.info("Health check detected")
80
+ return {"health_check": True}
81
+
82
  # HF Inference API format: 'inputs' is the text, 'parameters' contains the config
 
83
  if isinstance(data, dict) and "inputs" in data:
84
  # Standard HF format
85
  text = data["inputs"]
 
98
 
99
  # Format prompt with voice
100
  prompt = f"{voice}: {text}"
101
+ logger.info(f"Formatted prompt with voice {voice}")
102
 
103
  # Tokenize
104
  input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
 
116
  "temperature": temperature,
117
  "top_p": top_p,
118
  "max_new_tokens": max_new_tokens,
119
+ "repetition_penalty": repetition_penalty,
120
+ "health_check": False
121
  }
122
 
123
  def inference(self, inputs):
124
  """
125
  Run model inference on the preprocessed inputs
126
  """
127
+ # Handle health check
128
+ if inputs.get("health_check", False):
129
+ return {"status": "ok"}
130
+
131
  # Extract parameters
132
  input_ids = inputs["input_ids"]
133
  attention_mask = inputs["attention_mask"]
 
136
  max_new_tokens = inputs["max_new_tokens"]
137
  repetition_penalty = inputs["repetition_penalty"]
138
 
139
+ logger.info(f"Running inference with max_new_tokens={max_new_tokens}")
140
+
141
  # Generate output tokens
142
  with torch.no_grad():
143
  generated_ids = self.model.generate(
 
152
  eos_token_id=self.end_audio_token,
153
  )
154
 
155
+ logger.info(f"Generation complete, output shape: {generated_ids.shape}")
156
  return generated_ids
157
 
158
  def postprocess(self, generated_ids):
159
  """
160
  Process generated tokens into audio
161
  """
162
+ # Handle health check response
163
+ if isinstance(generated_ids, dict) and "status" in generated_ids:
164
+ return generated_ids
165
+
166
+ logger.info("Postprocessing generated tokens")
167
+
168
  # Find Start of Audio token
169
  token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)
170
 
171
  if len(token_indices[1]) > 0:
172
  last_occurrence_idx = token_indices[1][-1].item()
173
  cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
174
+ logger.info(f"Found start audio token at position {last_occurrence_idx}")
175
  else:
176
  cropped_tensor = generated_ids
177
+ logger.warning("No start audio token found")
178
 
179
  # Remove End of Audio tokens
180
  processed_rows = []
 
195
  # Generate audio from codes
196
  audio_samples = []
197
  for code_list in code_lists:
198
+ logger.info(f"Processing code list of length {len(code_list)}")
199
+ if len(code_list) > 0:
200
+ audio = self.redistribute_codes(code_list)
201
+ audio_samples.append(audio)
202
+ else:
203
+ logger.warning("Empty code list, no audio to generate")
204
+
205
+ if not audio_samples:
206
+ logger.error("No audio samples generated")
207
+ return {"error": "No audio samples generated"}
208
 
209
  # Return first (and only) audio sample
210
  audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
 
228
 
229
  # Encode as base64
230
  audio_b64 = base64.b64encode(wav_data).decode('utf-8')
231
+ logger.info(f"Audio encoded as base64, length: {len(audio_b64)}")
232
 
233
  return {
234
  "audio_b64": audio_b64,
 
272
  logger.info(f"Received request: {type(data)}")
273
 
274
  # Check if we need to handle the health check route
275
+ if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
276
+ logger.info("Processing health check request")
277
  return {"status": "ok"}
278
 
279
  preprocessed_inputs = self.preprocess(data)
 
284
  logger.error(f"Error processing request: {str(e)}")
285
  import traceback
286
  logger.error(traceback.format_exc())
287
+ return {"error": str(e)}
288
+
289
+ "