hypaai committed
Commit 4805443 · verified · 1 Parent(s): a07f15b

Update handler.py

Files changed (1):
  handler.py  +273 -275

handler.py CHANGED
@@ -19,147 +19,43 @@ rather than the custom audio generation handler you've defined.
 Create a `handler.py` file with your custom handler code:
 """
 
-import torch
-import numpy as np
-
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-class EndpointHandler():
-    def __init__(self, path=""):
-
-        # Load the models and tokenizer
-        self.model = AutoModelForCausalLM.from_pretrained(
-            "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit",
-            torch_dtype=torch.bfloat16
-        )
-        self.tokenizer = AutoTokenizer.from_pretrained("hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit")
-
-        # Move to devices
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model.to(self.device)
-
-        # Special tokens
-        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
-        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
-        self.padding_token = 128263
-        self.start_audio_token = 128257  # Start of Audio token
-        self.end_audio_token = 128258  # End of Audio token
-
-
-    def __call__(self, data):
-        """
-        Main entry point for the handler
-        """
-
-        # Preprocess input
-        if isinstance(data, dict) and "inputs" in data:
-            text = data["inputs"]
-            parameters = data.get("parameters", {})
-        else:
-            text = data
-            parameters = {}
-
-        # Extract parameters from request
-        voice = parameters.get("voice", "tara")
-        temperature = float(parameters.get("temperature", 0.6))
-        top_p = float(parameters.get("top_p", 0.95))
-        max_new_tokens = int(parameters.get("max_new_tokens", 1200))
-        repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
-
-        # Format prompt with voice
-        prompt = f"{voice}: {text}"
-
-        # Tokenize
-        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
-
-        # Add special tokens
-        modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)
-
-        # No need for padding as we're processing a single sequence
-        input_ids = modified_input_ids.to(self.device)
-        attention_mask = torch.ones_like(input_ids)
-
-        # Forward pass through the model
-        generated_ids = self.model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_new_tokens=max_new_tokens,
-            do_sample=True,
-            temperature=temperature,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            num_return_sequences=1,
-            eos_token_id=self.end_audio_token,
-        )
-
-        return generated_ids
-# # Code from your original handler, but with some fixes
-# import os
 # import torch
 # import numpy as np
-# from transformers import AutoModelForCausalLM, AutoTokenizer
-# from snac import SNAC
-# import logging
 
-# # Set up logging
-# logging.basicConfig(level=logging.INFO)
-# logger = logging.getLogger(__name__)
+# from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# class EndpointHandler:
+# class EndpointHandler():
 #     def __init__(self, path=""):
-#         logger.info("Initializing Orpheus TTS handler")
-#         # Load the Orpheus model and tokenizer
-#         self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
+
+#         # Load the models and tokenizer
 #         self.model = AutoModelForCausalLM.from_pretrained(
-#             self.model_name,
+#             "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit",
 #             torch_dtype=torch.bfloat16
 #         )
-
-#         # Move model to GPU if available
+#         self.tokenizer = AutoTokenizer.from_pretrained("hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit")
+
+#         # Move to devices
 #         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 #         self.model.to(self.device)
-#         logger.info(f"Model loaded on {self.device}")
-
-#         # Load tokenizer
-#         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-#         logger.info("Tokenizer loaded")
-
-#         # Load SNAC model for audio decoding
-#         try:
-#             self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-#             self.snac_model.to(self.device)
-#             logger.info("SNAC model loaded")
-#         except Exception as e:
-#             logger.error(f"Error loading SNAC: {str(e)}")
-#             raise
-
+
 #         # Special tokens
 #         self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
 #         self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
 #         self.padding_token = 128263
 #         self.start_audio_token = 128257  # Start of Audio token
 #         self.end_audio_token = 128258  # End of Audio token
-
-#         logger.info("Handler initialization complete")
-
-#     def preprocess(self, data):
+
+
+#     def __call__(self, data):
 #         """
-#         Preprocess input data before inference
+#         Main entry point for the handler
 #         """
-#         logger.info(f"Preprocessing data: {type(data)}")
 
-#         # Handle health check
-#         if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
-#             logger.info("Health check detected")
-#             return {"health_check": True}
-
-#         # HF Inference API format: 'inputs' is the text, 'parameters' contains the config
+#         # Preprocess input
 #         if isinstance(data, dict) and "inputs" in data:
-#             # Standard HF format
 #             text = data["inputs"]
 #             parameters = data.get("parameters", {})
 #         else:
-#             # Direct access (fallback)
 #             text = data
 #             parameters = {}
 
@@ -169,197 +65,299 @@ class EndpointHandler():
 #         top_p = float(parameters.get("top_p", 0.95))
 #         max_new_tokens = int(parameters.get("max_new_tokens", 1200))
 #         repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
-
+
 #         # Format prompt with voice
 #         prompt = f"{voice}: {text}"
-#         logger.info(f"Formatted prompt with voice {voice}")
-
+
 #         # Tokenize
 #         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
-
+
 #         # Add special tokens
 #         modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)
-
+
 #         # No need for padding as we're processing a single sequence
 #         input_ids = modified_input_ids.to(self.device)
 #         attention_mask = torch.ones_like(input_ids)
+
+#         # Forward pass through the model
+#         generated_ids = self.model.generate(
+#             input_ids=input_ids,
+#             attention_mask=attention_mask,
+#             max_new_tokens=max_new_tokens,
+#             do_sample=True,
+#             temperature=temperature,
+#             top_p=top_p,
+#             repetition_penalty=repetition_penalty,
+#             num_return_sequences=1,
+#             eos_token_id=self.end_audio_token,
+#         )
 
-#         return {
-#             "input_ids": input_ids,
-#             "attention_mask": attention_mask,
-#             "temperature": temperature,
-#             "top_p": top_p,
-#             "max_new_tokens": max_new_tokens,
-#             "repetition_penalty": repetition_penalty,
-#             "health_check": False
-#         }
+#         return generated_ids
+# Code from your original handler, but with some fixes
+import os
+import torch
+import numpy as np
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from snac import SNAC
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        logger.info("Initializing Orpheus TTS handler")
+        # Load the Orpheus model and tokenizer
+        self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            torch_dtype=torch.bfloat16
+        )
+
+        # Move model to GPU if available
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model.to(self.device)
+        logger.info(f"Model loaded on {self.device}")
+
+        # Load tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        logger.info("Tokenizer loaded")
+
+        # Load SNAC model for audio decoding
+        try:
+            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+            self.snac_model.to(self.device)
+            logger.info("SNAC model loaded")
+        except Exception as e:
+            logger.error(f"Error loading SNAC: {str(e)}")
+            raise
+
+        # Special tokens
+        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
+        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
+        self.padding_token = 128263
+        self.start_audio_token = 128257  # Start of Audio token
+        self.end_audio_token = 128258  # End of Audio token
+
+        logger.info("Handler initialization complete")
+
+    def preprocess(self, data):
+        """
+        Preprocess input data before inference
+        """
+        logger.info(f"Preprocessing data: {type(data)}")
+
+        # Handle health check
+        if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
+            logger.info("Health check detected")
+            return {"health_check": True}
+
+        # HF Inference API format: 'inputs' is the text, 'parameters' contains the config
+        if isinstance(data, dict) and "inputs" in data:
+            # Standard HF format
+            text = data["inputs"]
+            parameters = data.get("parameters", {})
+        else:
+            # Direct access (fallback)
+            text = data
+            parameters = {}
+
+        # Extract parameters from request
+        voice = parameters.get("voice", "tara")
+        temperature = float(parameters.get("temperature", 0.6))
+        top_p = float(parameters.get("top_p", 0.95))
+        max_new_tokens = int(parameters.get("max_new_tokens", 1200))
+        repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
+
+        # Format prompt with voice
+        prompt = f"{voice}: {text}"
+        logger.info(f"Formatted prompt with voice {voice}")
+
+        # Tokenize
+        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+
+        # Add special tokens
+        modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)
+
+        # No need for padding as we're processing a single sequence
+        input_ids = modified_input_ids.to(self.device)
+        attention_mask = torch.ones_like(input_ids)
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "temperature": temperature,
+            "top_p": top_p,
+            "max_new_tokens": max_new_tokens,
+            "repetition_penalty": repetition_penalty,
+            "health_check": False
+        }
 
-#     def inference(self, inputs):
-#         """
-#         Run model inference on the preprocessed inputs
-#         """
-#         # Handle health check
-#         if inputs.get("health_check", False):
-#             return {"status": "ok"}
+    def inference(self, inputs):
+        """
+        Run model inference on the preprocessed inputs
+        """
+        # Handle health check
+        if inputs.get("health_check", False):
+            return {"status": "ok"}
 
-#         # Extract parameters
-#         input_ids = inputs["input_ids"]
-#         attention_mask = inputs["attention_mask"]
-#         temperature = inputs["temperature"]
-#         top_p = inputs["top_p"]
-#         max_new_tokens = inputs["max_new_tokens"]
-#         repetition_penalty = inputs["repetition_penalty"]
+        # Extract parameters
+        input_ids = inputs["input_ids"]
+        attention_mask = inputs["attention_mask"]
+        temperature = inputs["temperature"]
+        top_p = inputs["top_p"]
+        max_new_tokens = inputs["max_new_tokens"]
+        repetition_penalty = inputs["repetition_penalty"]
 
-#         logger.info(f"Running inference with max_new_tokens={max_new_tokens}")
+        logger.info(f"Running inference with max_new_tokens={max_new_tokens}")
 
-#         # Generate output tokens
-#         with torch.no_grad():
-#             generated_ids = self.model.generate(
-#                 input_ids=input_ids,
-#                 attention_mask=attention_mask,
-#                 max_new_tokens=max_new_tokens,
-#                 do_sample=True,
-#                 temperature=temperature,
-#                 top_p=top_p,
-#                 repetition_penalty=repetition_penalty,
-#                 num_return_sequences=1,
-#                 eos_token_id=self.end_audio_token,
-#             )
+        # Generate output tokens
+        with torch.no_grad():
+            generated_ids = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=max_new_tokens,
+                do_sample=True,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                num_return_sequences=1,
+                eos_token_id=self.end_audio_token,
+            )
 
-#         logger.info(f"Generation complete, output shape: {generated_ids.shape}")
-#         return generated_ids
+        logger.info(f"Generation complete, output shape: {generated_ids.shape}")
+        return generated_ids
 
-#     def postprocess(self, generated_ids):
-#         """
-#         Process generated tokens into audio
-#         """
-#         # Handle health check response
-#         if isinstance(generated_ids, dict) and "status" in generated_ids:
-#             return generated_ids
+    def postprocess(self, generated_ids):
+        """
+        Process generated tokens into audio
+        """
+        # Handle health check response
+        if isinstance(generated_ids, dict) and "status" in generated_ids:
+            return generated_ids
 
-#         logger.info("Postprocessing generated tokens")
+        logger.info("Postprocessing generated tokens")
 
-#         # Find Start of Audio token
-#         token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)
+        # Find Start of Audio token
+        token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)
 
-#         if len(token_indices[1]) > 0:
-#             last_occurrence_idx = token_indices[1][-1].item()
-#             cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
-#             logger.info(f"Found start audio token at position {last_occurrence_idx}")
-#         else:
-#             cropped_tensor = generated_ids
-#             logger.warning("No start audio token found")
+        if len(token_indices[1]) > 0:
+            last_occurrence_idx = token_indices[1][-1].item()
+            cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+            logger.info(f"Found start audio token at position {last_occurrence_idx}")
+        else:
+            cropped_tensor = generated_ids
+            logger.warning("No start audio token found")
 
-#         # Remove End of Audio tokens
-#         processed_rows = []
-#         for row in cropped_tensor:
-#             masked_row = row[row != self.end_audio_token]
-#             processed_rows.append(masked_row)
+        # Remove End of Audio tokens
+        processed_rows = []
+        for row in cropped_tensor:
+            masked_row = row[row != self.end_audio_token]
+            processed_rows.append(masked_row)
 
-#         # Prepare audio codes
-#         code_lists = []
-#         for row in processed_rows:
-#             row_length = row.size(0)
-#             # Ensure length is multiple of 7 for SNAC
-#             new_length = (row_length // 7) * 7
-#             trimmed_row = row[:new_length]
-#             trimmed_row = [t.item() - 128266 for t in trimmed_row]  # Adjust token values
-#             code_lists.append(trimmed_row)
+        # Prepare audio codes
+        code_lists = []
+        for row in processed_rows:
+            row_length = row.size(0)
+            # Ensure length is multiple of 7 for SNAC
+            new_length = (row_length // 7) * 7
+            trimmed_row = row[:new_length]
+            trimmed_row = [t.item() - 128266 for t in trimmed_row]  # Adjust token values
+            code_lists.append(trimmed_row)
 
-#         # Generate audio from codes
-#         audio_samples = []
-#         for code_list in code_lists:
-#             logger.info(f"Processing code list of length {len(code_list)}")
-#             if len(code_list) > 0:
-#                 audio = self.redistribute_codes(code_list)
-#                 audio_samples.append(audio)
-#             else:
-#                 logger.warning("Empty code list, no audio to generate")
+        # Generate audio from codes
+        audio_samples = []
+        for code_list in code_lists:
+            logger.info(f"Processing code list of length {len(code_list)}")
+            if len(code_list) > 0:
+                audio = self.redistribute_codes(code_list)
+                audio_samples.append(audio)
+            else:
+                logger.warning("Empty code list, no audio to generate")
 
-#         if not audio_samples:
-#             logger.error("No audio samples generated")
-#             return {"error": "No audio samples generated"}
+        if not audio_samples:
+            logger.error("No audio samples generated")
+            return {"error": "No audio samples generated"}
 
-#         # Return first (and only) audio sample
-#         audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
+        # Return first (and only) audio sample
+        audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
 
-#         # Convert to base64 for transmission
-#         import base64
-#         import io
-#         import wave
+        # Convert to base64 for transmission
+        import base64
+        import io
+        import wave
 
-#         # Convert float32 array to int16 for WAV format
-#         audio_int16 = (audio_sample * 32767).astype(np.int16)
+        # Convert float32 array to int16 for WAV format
+        audio_int16 = (audio_sample * 32767).astype(np.int16)
 
-#         # Create WAV in memory
-#         with io.BytesIO() as wav_io:
-#             with wave.open(wav_io, 'wb') as wav_file:
-#                 wav_file.setnchannels(1)  # Mono
-#                 wav_file.setsampwidth(2)  # 16-bit
-#                 wav_file.setframerate(24000)  # 24kHz
-#                 wav_file.writeframes(audio_int16.tobytes())
-#             wav_data = wav_io.getvalue()
+        # Create WAV in memory
+        with io.BytesIO() as wav_io:
+            with wave.open(wav_io, 'wb') as wav_file:
+                wav_file.setnchannels(1)  # Mono
+                wav_file.setsampwidth(2)  # 16-bit
+                wav_file.setframerate(24000)  # 24kHz
+                wav_file.writeframes(audio_int16.tobytes())
            wav_data = wav_io.getvalue()
 
-#         # Encode as base64
-#         audio_b64 = base64.b64encode(wav_data).decode('utf-8')
-#         logger.info(f"Audio encoded as base64, length: {len(audio_b64)}")
+        # Encode as base64
+        audio_b64 = base64.b64encode(wav_data).decode('utf-8')
+        logger.info(f"Audio encoded as base64, length: {len(audio_b64)}")
 
-#         return {
-#             "generated_ids": generated_ids.tolist(), #OOO 05102025
-#             "audio_b64": audio_b64,
-#             "sample_rate": 24000
-#         }
+        return {
+            "generated_ids": generated_ids.tolist(), #OOO 05102025
+            "audio_b64": audio_b64,
+            "sample_rate": 24000
+        }
 
-#     def redistribute_codes(self, code_list):
-#         """
-#         Reorganize tokens for SNAC decoding
-#         """
-#         layer_1 = []  # Coarsest layer
-#         layer_2 = []  # Intermediate layer
-#         layer_3 = []  # Finest layer
+    def redistribute_codes(self, code_list):
+        """
+        Reorganize tokens for SNAC decoding
+        """
+        layer_1 = []  # Coarsest layer
+        layer_2 = []  # Intermediate layer
+        layer_3 = []  # Finest layer
 
-#         num_groups = len(code_list) // 7
-#         for i in range(num_groups):
-#             idx = 7 * i
-#             layer_1.append(code_list[idx])
-#             layer_2.append(code_list[idx + 1] - 4096)
-#             layer_3.append(code_list[idx + 2] - (2 * 4096))
-#             layer_3.append(code_list[idx + 3] - (3 * 4096))
-#             layer_2.append(code_list[idx + 4] - (4 * 4096))
-#             layer_3.append(code_list[idx + 5] - (5 * 4096))
-#             layer_3.append(code_list[idx + 6] - (6 * 4096))
+        num_groups = len(code_list) // 7
+        for i in range(num_groups):
+            idx = 7 * i
+            layer_1.append(code_list[idx])
+            layer_2.append(code_list[idx + 1] - 4096)
+            layer_3.append(code_list[idx + 2] - (2 * 4096))
+            layer_3.append(code_list[idx + 3] - (3 * 4096))
+            layer_2.append(code_list[idx + 4] - (4 * 4096))
+            layer_3.append(code_list[idx + 5] - (5 * 4096))
+            layer_3.append(code_list[idx + 6] - (6 * 4096))
 
-#         codes = [
-#             torch.tensor(layer_1).unsqueeze(0).to(self.device),
-#             torch.tensor(layer_2).unsqueeze(0).to(self.device),
-#             torch.tensor(layer_3).unsqueeze(0).to(self.device)
-#         ]
+        codes = [
+            torch.tensor(layer_1).unsqueeze(0).to(self.device),
+            torch.tensor(layer_2).unsqueeze(0).to(self.device),
+            torch.tensor(layer_3).unsqueeze(0).to(self.device)
+        ]
 
-#         # Decode audio
-#         audio_hat = self.snac_model.decode(codes)
-#         return audio_hat
+        # Decode audio
+        audio_hat = self.snac_model.decode(codes)
+        return audio_hat
 
-#     def __call__(self, data):
-#         """
-#         Main entry point for the handler
-#         """
-#         return "Testing the mike, one two"
-#         # try:
-#         #     logger.info(f"Received request: {type(data)}")
+    def __call__(self, data):
+        """
+        Main entry point for the handler
+        """
+        try:
+            logger.info(f"Received request: {type(data)}")
 
-#         #     # Check if we need to handle the health check route
-#         #     if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
-#         #         logger.info("Processing health check request")
-#         #         return {"status": "ok"}
+            # Check if we need to handle the health check route
+            if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
+                logger.info("Processing health check request")
+                return {"status": "ok"}
 
-#         #     preprocessed_inputs = self.preprocess(data)
-#         #     model_outputs = self.inference(preprocessed_inputs)
-#         #     response = self.postprocess(model_outputs)
-#         #     response = reponse["generated_ids"] #OOO 05102025
-#         #     return response
-#         # except Exception as e:
-#         #     logger.error(f"Error processing request: {str(e)}")
-#         #     import traceback
-#         #     logger.error(traceback.format_exc())
-#         #     return {"error": str(e)}
+            preprocessed_inputs = self.preprocess(data)
+            model_outputs = self.inference(preprocessed_inputs)
+            response = self.postprocess(model_outputs)
+            return response
+        except Exception as e:
+            logger.error(f"Error processing request: {str(e)}")
+            import traceback
+            logger.error(traceback.format_exc())
+            return {"error": str(e)}
 
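For reference, the activated redistribute_codes assumes the model emits SNAC codes in groups of 7 tokens, where slot k of each group stores its code offset by k * 4096 (so each slot owns a 4096-wide range) and the slots are routed to SNAC's three layers as 1-2-3-3-2-3-3: one coarse, two intermediate, and four fine codes per group. A minimal sketch of that layout, using made-up code values:

# Sketch: layout of one 7-token group as consumed by redistribute_codes.
# The code values below are hypothetical; slot k carries (code + k * 4096).
group = [12, 4096 + 7, 2 * 4096 + 99, 3 * 4096 + 5,
         4 * 4096 + 31, 5 * 4096 + 2, 6 * 4096 + 64]

layer_1 = [group[0]]                                   # coarsest layer: 1 code
layer_2 = [group[1] - 4096, group[4] - 4 * 4096]       # intermediate layer: 2 codes
layer_3 = [group[2] - 2 * 4096, group[3] - 3 * 4096,
           group[5] - 5 * 4096, group[6] - 6 * 4096]   # finest layer: 4 codes

assert (layer_1, layer_2, layer_3) == ([12], [7, 31], [99, 5, 2, 64])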
 
 
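Before deploying, the handler can be smoke-tested locally. This is a sketch only: it assumes handler.py is importable from the working directory and that the machine can hold the 3B model plus SNAC in memory (weights are downloaded on first run).

# Hypothetical local test of the custom handler.
from handler import EndpointHandler

handler = EndpointHandler()

# The health-check route short-circuits inside __call__ / preprocess.
assert handler("ping") == {"status": "ok"}

result = handler({
    "inputs": "Testing the mic, one two.",
    "parameters": {"voice": "tara", "max_new_tokens": 1200},
})
print(sorted(result.keys()))  # expect: audio_b64, generated_ids, sample_rate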
 
 
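Once deployed as a custom Inference Endpoint, the handler accepts the standard {"inputs": ..., "parameters": {...}} payload and returns the audio as a base64-encoded 24 kHz mono WAV. A minimal client sketch; the endpoint URL and token below are placeholders, not values from this repo:

import base64
import requests

ENDPOINT_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder

payload = {
    "inputs": "Hello from Orpheus!",
    "parameters": {
        "voice": "tara",             # handler default
        "temperature": 0.6,
        "top_p": 0.95,
        "max_new_tokens": 1200,
        "repetition_penalty": 1.1,
    },
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
    timeout=300,
)
resp.raise_for_status()
body = resp.json()

# The handler returns {"generated_ids": ..., "audio_b64": ..., "sample_rate": 24000}.
with open("output.wav", "wb") as f:
    f.write(base64.b64decode(body["audio_b64"]))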