Updated handler.py with claude

handler.py  CHANGED  (+186 -193)
--- handler.py (old)

import torch
import numpy as np
import soundfile as sf
import io
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC

# --- Helper Function (can be outside or inside the class) ---
def redistribute_codes_static(code_list):
    """ Reorganizes the flattened token list into three separate layers for SNAC. """
    layer_1, layer_2, layer_3 = [], [], []
    num_groups = len(code_list) // 7  # Use floor division
    for i in range(num_groups):
        idx = 7 * i
        layer_1.append(code_list[idx])
        layer_2.append(code_list[idx + 1] - 4096)
        layer_3.append(code_list[idx + 2] - (2 * 4096))
        layer_3.append(code_list[idx + 3] - (3 * 4096))
        layer_2.append(code_list[idx + 4] - (4 * 4096))
        layer_3.append(code_list[idx + 5] - (5 * 4096))
        layer_3.append(code_list[idx + 6] - (6 * 4096))

    # Assemble the three layers as [1, T] long tensors for SNAC
    codes = [None, None, None]
    if layer_1: codes[0] = torch.tensor(layer_1).unsqueeze(0).long()
    if layer_2: codes[1] = torch.tensor(layer_2).unsqueeze(0).long()
    if layer_3: codes[2] = torch.tensor(layer_3).unsqueeze(0).long()

    return codes

# --- Endpoint Handler Class ---
class EndpointHandler():
    def __init__(self, path=""):
        """
        Initializes the handler. Loads both the Orpheus LLM and the SNAC vocoder.
        'path' points to the directory containing the Orpheus model files
        specified in the endpoint config.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Define model names/paths.
        # The Orpheus LLM path is determined by the endpoint configuration ('path').
        orpheus_model_path = path if path else "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
        snac_model_name = "hubertsiuzdak/snac_24khz"

        # Define special token IDs
        self.start_human_token_id = 128259
        self.end_text_token_id = 128009
        self.end_human_token_id = 128260
        # self.padding_token_id = 128263  # Not needed for single-sequence generation
        self.start_audio_token_id = 128257
        self.end_audio_token_id = 128258
        self.audio_code_offset = 128266

        # Define sampling rate
        self.sampling_rate = 24000

        try:
            # Load Orpheus LLM and tokenizer
            print(f"Loading Orpheus tokenizer from: {orpheus_model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(orpheus_model_path)
            print(f"Loading Orpheus model from: {orpheus_model_path}")
            self.model = AutoModelForCausalLM.from_pretrained(
                orpheus_model_path,
                torch_dtype=torch.bfloat16  # Use bfloat16, matching the training script
            )
            self.model.to(self.device)
            self.model.eval()  # Set model to evaluation mode
            print("Orpheus model and tokenizer loaded successfully.")

            # Load SNAC vocoder
            print(f"Loading SNAC model from: {snac_model_name}")
            self.snac_model = SNAC.from_pretrained(snac_model_name)
            self.snac_model.to(self.device)  # Move SNAC to the same device
            self.snac_model.eval()  # Set model to evaluation mode
            print("SNAC model loaded successfully.")

        except Exception as e:
            print(f"Error during model loading: {e}")
            raise RuntimeError("Failed to load models.") from e

    def __call__(self, data: dict) -> bytes:
        """
        Expects data['inputs'] (text) and optionally data['parameters'].
        """
        try:
            text = data.get("inputs")
            if not text:
                raise ValueError("Missing 'inputs' key in request data")

            parameters = data.pop("parameters", {})
            # Default voice if not provided
            voice = parameters.get("voice", "Eniola")
            # Default generation parameters (merged with any provided ones)
            gen_params = {
                "max_new_tokens": 1200,
                "do_sample": True,
                "temperature": 0.6,
                "top_p": 0.95,
                "repetition_penalty": 1.1,
                "num_return_sequences": 1,
                "eos_token_id": self.end_audio_token_id,
                **parameters  # Overwrite defaults with user params
            }
            # Remove non-generate params if they were passed
            gen_params.pop("voice", None)

            print(f"Received request: text='{text[:50]}...', voice='{voice}', params={gen_params}")

            # --- Preprocess Text ---
            prompt = f"{voice}: {text}"
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

            # Add special tokens: SOH + input tokens + EOT + EOH
            start_token_tensor = torch.tensor([[self.start_human_token_id]], dtype=torch.int64)
            end_tokens_tensor = torch.tensor([[self.end_text_token_id, self.end_human_token_id]], dtype=torch.int64)
            processed_input_ids = torch.cat([start_token_tensor, input_ids, end_tokens_tensor], dim=1).to(self.device)

            # Create attention mask (all ones for a single, unpadded sequence)
            attention_mask = torch.ones_like(processed_input_ids).to(self.device)

            print(f"Processed input shape: {processed_input_ids.shape}")

            # --- Generate Audio Codes (LLM inference) ---
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=processed_input_ids,
                    attention_mask=attention_mask,
                    **gen_params
                )
            print(f"Generated IDs shape: {generated_ids.shape}")

            # --- Process Generated Tokens (extract audio codes) ---
            # Find the last Start of Audio token
            soa_indices = (generated_ids[0] == self.start_audio_token_id).nonzero(as_tuple=True)[0]
            if len(soa_indices) == 0:
                print("Warning: Start of Audio token (128257) not found in generated sequence!")
                # Fall back to processing all newly generated tokens
                start_idx = processed_input_ids.shape[1]  # Start after the input prompt
            else:
                start_idx = soa_indices[-1].item() + 1  # Start after the last SOA token

            # Extract potential audio codes (after the last SOA token, or after the input)
            cropped_tokens = generated_ids[0, start_idx:]

            # Remove End of Audio tokens
            audio_codes_raw = cropped_tokens[cropped_tokens != self.end_audio_token_id]
            print(f"Extracted raw audio codes count: {len(audio_codes_raw)}")

            if len(audio_codes_raw) == 0:
                raise ValueError("No audio codes generated or extracted after processing.")

            # --- Prepare Codes for SNAC Vocoder ---
            # Adjust token values
            adjusted_codes = [t.item() - self.audio_code_offset for t in audio_codes_raw]

            # Trim to a multiple of 7
            num_codes = len(adjusted_codes)
            valid_length = (num_codes // 7) * 7
            if valid_length == 0:
                raise ValueError(f"Not enough audio codes ({num_codes}) to form a multiple of 7 after processing.")
            trimmed_codes = adjusted_codes[:valid_length]
            print(f"Trimmed adjusted audio codes count: {len(trimmed_codes)}")

            # --- Redistribute Codes ---
            # Ensure the tensors end up on the correct device
            snac_input_codes = redistribute_codes_static(trimmed_codes)
            snac_input_codes = [layer.to(self.device) for layer in snac_input_codes]

            # --- Decode Audio (SNAC inference) ---
            print("Decoding audio with SNAC...")
            with torch.no_grad():
                audio_hat = self.snac_model.decode(snac_input_codes)
            print(f"Decoded audio tensor shape: {audio_hat.shape}")  # Should be [1, 1, num_samples]

            # --- Postprocess Audio ---
            # Move to CPU, remove batch/channel dims, convert to numpy
            audio_waveform = audio_hat.detach().squeeze().cpu().numpy()

            # --- Convert to WAV Bytes ---
            buffer = io.BytesIO()
            sf.write(buffer, audio_waveform, self.sampling_rate, format='WAV')
            buffer.seek(0)
            wav_bytes = buffer.read()

            print(f"Generated {len(wav_bytes)} bytes of WAV audio.")
            return wav_bytes

        except Exception as e:
            print(f"Error during inference call: {e}")
            # Re-raise for the endpoint framework
            raise RuntimeError(f"Inference failed: {e}")
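
The removed helper above and its replacement below encode the same layout: Orpheus emits audio codes in flat groups of seven, and position k within each group carries an additive offset of k * 4096 that must be removed before SNAC sees the code. A minimal sketch of that mapping for a single frame (standalone, not part of handler.py; the concrete code values are made up for illustration):

# One flattened 7-token frame -> SNAC's three codebook layers.
frame = [12, 4096 + 34, 2 * 4096 + 56, 3 * 4096 + 78,
         4 * 4096 + 90, 5 * 4096 + 11, 6 * 4096 + 22]

layer_1 = [frame[0]]                              # 1 coarse code per frame
layer_2 = [frame[1] - 4096, frame[4] - 4 * 4096]  # 2 mid codes per frame
layer_3 = [frame[2] - 2 * 4096, frame[3] - 3 * 4096,
           frame[5] - 5 * 4096, frame[6] - 6 * 4096]  # 4 fine codes per frame

assert (layer_1, layer_2, layer_3) == ([12], [34, 90], [56, 78, 11, 22])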
+++ handler.py (new)

import os
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC

class EndpointHandler:
    def __init__(self, path=""):
        # Load the Orpheus model and tokenizer
        self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16
        )

        # Move model to GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Load SNAC model for audio decoding
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model.to(self.device)

        # Special tokens
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
        self.padding_token = 128263
        self.start_audio_token = 128257  # Start of Audio token
        self.end_audio_token = 128258  # End of Audio token

        print(f"Model loaded on {self.device}")

    def preprocess(self, data):
        """
        Preprocess input data before inference.
        """
        inputs = data.pop("inputs", data)

        # Extract parameters from the request
        text = inputs.get("text", "")
        voice = inputs.get("voice", "tara")
        temperature = float(inputs.get("temperature", 0.6))
        top_p = float(inputs.get("top_p", 0.95))
        max_new_tokens = int(inputs.get("max_new_tokens", 1200))
        repetition_penalty = float(inputs.get("repetition_penalty", 1.1))

        # Format prompt with voice
        prompt = f"{voice}: {text}"

        # Tokenize
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

        # Add special tokens
        modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)

        # No padding needed, since we process a single sequence
        input_ids = modified_input_ids.to(self.device)
        attention_mask = torch.ones_like(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty
        }

    def inference(self, inputs):
        """
        Run model inference on the preprocessed inputs.
        """
        # Extract parameters
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        temperature = inputs["temperature"]
        top_p = inputs["top_p"]
        max_new_tokens = inputs["max_new_tokens"]
        repetition_penalty = inputs["repetition_penalty"]

        # Generate output tokens
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.end_audio_token,
            )

        return generated_ids

    def postprocess(self, generated_ids):
        """
        Process generated tokens into audio.
        """
        # Find the Start of Audio token
        token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)

        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
        else:
            cropped_tensor = generated_ids

        # Remove End of Audio tokens
        processed_rows = []
        for row in cropped_tensor:
            masked_row = row[row != self.end_audio_token]
            processed_rows.append(masked_row)

        # Prepare audio codes
        code_lists = []
        for row in processed_rows:
            row_length = row.size(0)
            # Ensure length is a multiple of 7 for SNAC
            new_length = (row_length // 7) * 7
            trimmed_row = row[:new_length]
            trimmed_row = [t.item() - 128266 for t in trimmed_row]  # Adjust token values
            code_lists.append(trimmed_row)

        # Generate audio from codes
        audio_samples = []
        for code_list in code_lists:
            audio = self.redistribute_codes(code_list)
            audio_samples.append(audio)

        # Return the first (and only) audio sample
        audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()

        # Convert to base64 for transmission
        import base64
        import io
        import wave

        # Convert float32 array to int16 for WAV format
        audio_int16 = (audio_sample * 32767).astype(np.int16)

        # Create WAV in memory
        with io.BytesIO() as wav_io:
            with wave.open(wav_io, 'wb') as wav_file:
                wav_file.setnchannels(1)  # Mono
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(24000)  # 24 kHz
                wav_file.writeframes(audio_int16.tobytes())
            wav_data = wav_io.getvalue()

        # Encode as base64
        audio_b64 = base64.b64encode(wav_data).decode('utf-8')

        return {
            "audio_b64": audio_b64,
            "sample_rate": 24000
        }

    def redistribute_codes(self, code_list):
        """
        Reorganize tokens for SNAC decoding.
        """
        layer_1 = []  # Coarsest layer
        layer_2 = []  # Intermediate layer
        layer_3 = []  # Finest layer

        num_groups = len(code_list) // 7
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx])
            layer_2.append(code_list[idx + 1] - 4096)
            layer_3.append(code_list[idx + 2] - (2 * 4096))
            layer_3.append(code_list[idx + 3] - (3 * 4096))
            layer_2.append(code_list[idx + 4] - (4 * 4096))
            layer_3.append(code_list[idx + 5] - (5 * 4096))
            layer_3.append(code_list[idx + 6] - (6 * 4096))

        codes = [
            torch.tensor(layer_1).unsqueeze(0).to(self.device),
            torch.tensor(layer_2).unsqueeze(0).to(self.device),
            torch.tensor(layer_3).unsqueeze(0).to(self.device)
        ]

        # Decode audio
        audio_hat = self.snac_model.decode(codes)
        return audio_hat

    def __call__(self, data):
        """
        Main entry point for the handler.
        """
        preprocessed_inputs = self.preprocess(data)
        model_outputs = self.inference(preprocessed_inputs)
        response = self.postprocess(model_outputs)
        return response
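
For a quick local check of the new entry point, something like the following should work (a sketch: it assumes this file is importable as handler and that the machine can load both models):

import base64
from handler import EndpointHandler

handler = EndpointHandler()
response = handler({"inputs": {"text": "Hello from Orpheus!", "voice": "tara"}})

# The handler returns 16-bit mono WAV at 24 kHz, base64-encoded.
with open("output.wav", "wb") as f:
    f.write(base64.b64decode(response["audio_b64"]))
print("Wrote output.wav at", response["sample_rate"], "Hz")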
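
Once deployed as an Inference Endpoint, a client call would look roughly like this (the URL and token are placeholders, and it assumes the endpoint serializes the handler's return dict as JSON):

import base64
import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HEADERS = {"Authorization": "Bearer <hf_token>", "Content-Type": "application/json"}

payload = {"inputs": {"text": "Good morning!", "voice": "tara", "temperature": 0.6}}
resp = requests.post(API_URL, headers=HEADERS, json=payload)
resp.raise_for_status()

body = resp.json()
with open("speech.wav", "wb") as f:
    f.write(base64.b64decode(body["audio_b64"]))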