Update handler.py

handler.py CHANGED (+273 -275)

@@ -19,147 +19,43 @@ rather than the custom audio generation handler you've defined.
Unchanged context (handler.py lines 19-21, the tail of the module docstring):

```python
Create a `handler.py` file with your custom handler code:
"""
```

Removed: the minimal handler that was active before this commit (old lines 22-95). It builds the prompt and generates, but returns raw token IDs:

```python
import torch
import numpy as np

from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler():
    def __init__(self, path=""):

        # Load the models and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit",
            torch_dtype=torch.bfloat16
        )
        self.tokenizer = AutoTokenizer.from_pretrained("hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit")

        # Move to devices
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # Special tokens
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
        self.padding_token = 128263
        self.start_audio_token = 128257  # Start of Audio token
        self.end_audio_token = 128258  # End of Audio token

    def __call__(self, data):
        """
        Main entry point for the handler
        """
        # Preprocess input
        if isinstance(data, dict) and "inputs" in data:
            text = data["inputs"]
            parameters = data.get("parameters", {})
        else:
            text = data
            parameters = {}

        # Extract parameters from request
        voice = parameters.get("voice", "tara")
        temperature = float(parameters.get("temperature", 0.6))
        top_p = float(parameters.get("top_p", 0.95))
        max_new_tokens = int(parameters.get("max_new_tokens", 1200))
        repetition_penalty = float(parameters.get("repetition_penalty", 1.1))

        # Format prompt with voice
        prompt = f"{voice}: {text}"

        # Tokenize
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

        # Add special tokens
        modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)

        # No need for padding as we're processing a single sequence
        input_ids = modified_input_ids.to(self.device)
        attention_mask = torch.ones_like(input_ids)

        # Forward pass through the model
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=1,
            eos_token_id=self.end_audio_token,
        )

        return generated_ids
```
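One detail worth flagging in the removed version: its `__call__` returned the raw output of `model.generate`, a tensor of token IDs rather than audio. A minimal sketch of what a caller saw (illustrative, not part of the commit):

```python
handler = EndpointHandler()
out = handler({"inputs": "hello", "parameters": {}})
print(type(out))  # <class 'torch.Tensor'>; token IDs, decoding to audio was left to the caller
```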
@@ -169,197 +65,299 @@ class EndpointHandler():

Also removed, across both hunks (old lines 96-365): a commented-out copy of the full handler, headed `# # Code from your original handler, but with some fixes`. In this view most of those lines survive only as stray `#` markers; the recoverable fragments (the commented imports, logging setup, model/tokenizer/SNAC loading, generation parameters, `# return`, `# # logger.info(f"Received request: {type(data)}")`, `# # return {"error": str(e)}`) match the code this commit adds back, uncommented, below.
Added: the new side of the diff. The file keeps the same docstring (lines 19-21) and retains the old minimal handler as a commented-out reference block (new lines 22-95: the code removed above, each line prefixed with `#`). From line 96 on, the full SNAC-backed handler is active.
Module setup and `__init__` (new lines 96-143):

```python
# Code from your original handler, but with some fixes
import os
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path=""):
        logger.info("Initializing Orpheus TTS handler")
        # Load the Orpheus model and tokenizer
        self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16
        )

        # Move model to GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logger.info(f"Model loaded on {self.device}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        logger.info("Tokenizer loaded")

        # Load SNAC model for audio decoding
        try:
            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
            self.snac_model.to(self.device)
            logger.info("SNAC model loaded")
        except Exception as e:
            logger.error(f"Error loading SNAC: {str(e)}")
            raise

        # Special tokens
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
        self.padding_token = 128263
        self.start_audio_token = 128257  # Start of Audio token
        self.end_audio_token = 128258  # End of Audio token

        logger.info("Handler initialization complete")
```
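The new imports mean the endpoint now depends on the `snac` package in addition to the usual stack. This commit does not touch dependency files, but a matching `requirements.txt` would presumably look something like this (package set assumed from the imports, versions unpinned):

```
torch
numpy
transformers
snac
```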
`preprocess` (new lines 145-195) normalizes the request into generation-ready tensors and parameters:

```python
    def preprocess(self, data):
        """
        Preprocess input data before inference
        """
        logger.info(f"Preprocessing data: {type(data)}")

        # Handle health check
        if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
            logger.info("Health check detected")
            return {"health_check": True}

        # HF Inference API format: 'inputs' is the text, 'parameters' contains the config
        if isinstance(data, dict) and "inputs" in data:
            # Standard HF format
            text = data["inputs"]
            parameters = data.get("parameters", {})
        else:
            # Direct access (fallback)
            text = data
            parameters = {}

        # Extract parameters from request
        voice = parameters.get("voice", "tara")
        temperature = float(parameters.get("temperature", 0.6))
        top_p = float(parameters.get("top_p", 0.95))
        max_new_tokens = int(parameters.get("max_new_tokens", 1200))
        repetition_penalty = float(parameters.get("repetition_penalty", 1.1))

        # Format prompt with voice
        prompt = f"{voice}: {text}"
        logger.info(f"Formatted prompt with voice {voice}")

        # Tokenize
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

        # Add special tokens
        modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)

        # No need for padding as we're processing a single sequence
        input_ids = modified_input_ids.to(self.device)
        attention_mask = torch.ones_like(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "health_check": False
        }
```
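For reference, the request shape `preprocess` expects is the standard Inference Endpoints JSON body. A sketch of a matching payload (text and parameter values are placeholders; the defaults shown are the ones the code falls back to):

```python
payload = {
    "inputs": "Hello from the Orpheus handler.",
    "parameters": {
        "voice": "tara",              # speaker name prepended to the prompt
        "temperature": 0.6,
        "top_p": 0.95,
        "max_new_tokens": 1200,
        "repetition_penalty": 1.1,
    },
}
```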
`inference` (new lines 197-230) runs generation under `torch.no_grad()` and stops at the End of Audio token:

```python
    def inference(self, inputs):
        """
        Run model inference on the preprocessed inputs
        """
        # Handle health check
        if inputs.get("health_check", False):
            return {"status": "ok"}

        # Extract parameters
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        temperature = inputs["temperature"]
        top_p = inputs["top_p"]
        max_new_tokens = inputs["max_new_tokens"]
        repetition_penalty = inputs["repetition_penalty"]

        logger.info(f"Running inference with max_new_tokens={max_new_tokens}")

        # Generate output tokens
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.end_audio_token,
            )

        logger.info(f"Generation complete, output shape: {generated_ids.shape}")
        return generated_ids
```
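Putting `preprocess` and `inference` together, the sequence the model sees and produces has a fixed frame, reconstructed here from the special-token constants above (token counts are illustrative):

```python
# 128259 <prompt token IDs> 128009 128260   -> Start of human, text, End of text, End of human
# 128257 <audio code tokens ...> 128258     -> Start of Audio, codes, End of Audio (generation stops here)
```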
`postprocess` (new lines 232-311) crops the generated sequence to the audio tokens, decodes them with SNAC, and returns base64-encoded WAV:

```python
    def postprocess(self, generated_ids):
        """
        Process generated tokens into audio
        """
        # Handle health check response
        if isinstance(generated_ids, dict) and "status" in generated_ids:
            return generated_ids

        logger.info("Postprocessing generated tokens")

        # Find Start of Audio token
        token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)

        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
            logger.info(f"Found start audio token at position {last_occurrence_idx}")
        else:
            cropped_tensor = generated_ids
            logger.warning("No start audio token found")

        # Remove End of Audio tokens
        processed_rows = []
        for row in cropped_tensor:
            masked_row = row[row != self.end_audio_token]
            processed_rows.append(masked_row)

        # Prepare audio codes
        code_lists = []
        for row in processed_rows:
            row_length = row.size(0)
            # Ensure length is multiple of 7 for SNAC
            new_length = (row_length // 7) * 7
            trimmed_row = row[:new_length]
            trimmed_row = [t.item() - 128266 for t in trimmed_row]  # Adjust token values
            code_lists.append(trimmed_row)

        # Generate audio from codes
        audio_samples = []
        for code_list in code_lists:
            logger.info(f"Processing code list of length {len(code_list)}")
            if len(code_list) > 0:
                audio = self.redistribute_codes(code_list)
                audio_samples.append(audio)
            else:
                logger.warning("Empty code list, no audio to generate")

        if not audio_samples:
            logger.error("No audio samples generated")
            return {"error": "No audio samples generated"}

        # Return first (and only) audio sample
        audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()

        # Convert to base64 for transmission
        import base64
        import io
        import wave

        # Convert float32 array to int16 for WAV format
        audio_int16 = (audio_sample * 32767).astype(np.int16)

        # Create WAV in memory
        with io.BytesIO() as wav_io:
            with wave.open(wav_io, 'wb') as wav_file:
                wav_file.setnchannels(1)  # Mono
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(24000)  # 24kHz
                wav_file.writeframes(audio_int16.tobytes())
            wav_data = wav_io.getvalue()

        # Encode as base64
        audio_b64 = base64.b64encode(wav_data).decode('utf-8')
        logger.info(f"Audio encoded as base64, length: {len(audio_b64)}")

        return {
            "generated_ids": generated_ids.tolist(), #OOO 05102025
            "audio_b64": audio_b64,
            "sample_rate": 24000
        }
```
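Since `audio_b64` already contains a complete WAV container (header plus frames), a client can recover a playable file with the standard library alone. A minimal sketch, assuming `response` is the parsed JSON returned above:

```python
import base64

def save_wav(response: dict, path: str = "output.wav") -> None:
    # Decode the base64 payload and write the WAV bytes as-is
    with open(path, "wb") as f:
        f.write(base64.b64decode(response["audio_b64"]))
```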
`redistribute_codes` (new lines 313-340) de-interleaves the flat token stream into SNAC's three codebook layers:

```python
    def redistribute_codes(self, code_list):
        """
        Reorganize tokens for SNAC decoding
        """
        layer_1 = []  # Coarsest layer
        layer_2 = []  # Intermediate layer
        layer_3 = []  # Finest layer

        num_groups = len(code_list) // 7
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx])
            layer_2.append(code_list[idx + 1] - 4096)
            layer_3.append(code_list[idx + 2] - (2 * 4096))
            layer_3.append(code_list[idx + 3] - (3 * 4096))
            layer_2.append(code_list[idx + 4] - (4 * 4096))
            layer_3.append(code_list[idx + 5] - (5 * 4096))
            layer_3.append(code_list[idx + 6] - (6 * 4096))

        codes = [
            torch.tensor(layer_1).unsqueeze(0).to(self.device),
            torch.tensor(layer_2).unsqueeze(0).to(self.device),
            torch.tensor(layer_3).unsqueeze(0).to(self.device)
        ]

        # Decode audio
        audio_hat = self.snac_model.decode(codes)
        return audio_hat
```
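To make the interleaving concrete: each group of 7 tokens carries one coarse code, two intermediate codes, and four fine codes, and each position stores its code with an offset of `position * 4096` so the layers share one vocabulary range. A worked example for a single frame (values made up):

```python
# One 7-token frame, after the global -128266 shift done in postprocess:
frame = [12, 4100, 8200, 12300, 16400, 20500, 24600]

layer_1 = [frame[0]]                 # 12
layer_2 = [frame[1] - 4096,          # 4
           frame[4] - 4 * 4096]      # 16
layer_3 = [frame[2] - 2 * 4096,      # 8
           frame[3] - 3 * 4096,      # 12
           frame[5] - 5 * 4096,      # 20
           frame[6] - 6 * 4096]      # 24
```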
`__call__` (new lines 342-362) wires the three stages together and converts any failure into a JSON error response:

```python
    def __call__(self, data):
        """
        Main entry point for the handler
        """
        try:
            logger.info(f"Received request: {type(data)}")

            # Check if we need to handle the health check route
            if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
                logger.info("Processing health check request")
                return {"status": "ok"}

            preprocessed_inputs = self.preprocess(data)
            model_outputs = self.inference(preprocessed_inputs)
            response = self.postprocess(model_outputs)
            return response
        except Exception as e:
            logger.error(f"Error processing request: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            return {"error": str(e)}
```
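A quick local smoke test, assuming the model weights are reachable and the machine has enough memory; this exercises the handler the same way the endpoint runtime would:

```python
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({
        "inputs": "Testing the Orpheus text to speech handler.",
        "parameters": {"voice": "tara"},
    })
    # On success the dict contains generated_ids, audio_b64 and sample_rate (24000)
    print(sorted(result.keys()))
```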