hypaai committed
Commit 59a5c2a · verified · 1 Parent(s): dc6f6bd

Updated handler.py with Claude

Files changed (1):
  1. handler.py +186 -193

handler.py CHANGED
@@ -1,20 +1,176 @@
  import torch
  import numpy as np
- import soundfile as sf
- import io
- import os
  from transformers import AutoModelForCausalLM, AutoTokenizer
- from snac import SNAC  # Assuming SNAC is installed via requirements.txt
-
- # --- Helper Function (can be outside or inside the class) ---
- def redistribute_codes_static(code_list):
-     """Reorganizes the flattened token list into three separate layers for SNAC."""
-     layer_1, layer_2, layer_3 = [], [], []
-     num_groups = len(code_list) // 7  # Use floor division
-
-     for i in range(num_groups):
-         idx = 7 * i
-         try:
              layer_1.append(code_list[idx])
              layer_2.append(code_list[idx + 1] - 4096)
              layer_3.append(code_list[idx + 2] - (2 * 4096))
@@ -22,185 +178,22 @@ def redistribute_codes_static(code_list):
              layer_2.append(code_list[idx + 4] - (4 * 4096))
              layer_3.append(code_list[idx + 5] - (5 * 4096))
              layer_3.append(code_list[idx + 6] - (6 * 4096))
-         except IndexError:
-             print(f"Warning: Index out of range during code redistribution at group {i}. Code list length: {len(code_list)}")
-             break  # Stop processing if indices go out of bounds
-
-     # Ensure tensors are created even if empty
-     codes = [
-         torch.tensor(layer_1 or [0]).unsqueeze(0).long(),  # Use long dtype, provide default if empty
-         torch.tensor(layer_2 or [0]).unsqueeze(0).long(),
-         torch.tensor(layer_3 or [0]).unsqueeze(0).long()
-     ]
-     # Remove the default [0] if lists were actually populated
-     if layer_1: codes[0] = torch.tensor(layer_1).unsqueeze(0).long()
-     if layer_2: codes[1] = torch.tensor(layer_2).unsqueeze(0).long()
-     if layer_3: codes[2] = torch.tensor(layer_3).unsqueeze(0).long()
-
-     return codes
-
- # --- Endpoint Handler Class ---
- class EndpointHandler():
-     def __init__(self, path=""):
-         """
-         Initializes the handler. Loads both the Orpheus LLM and the SNAC vocoder.
-         'path' points to the directory containing the Orpheus model files specified in the endpoint config.
-         """
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         print(f"Using device: {self.device}")
-
-         # Define model names/paths
-         # Orpheus LLM path is determined by the endpoint configuration ('path' variable)
-         orpheus_model_path = path if path else "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
-         snac_model_name = "hubertsiuzdak/snac_24khz"
-
-         # Define special token IDs (matching the training script)
-         self.start_human_token_id = 128259
-         self.end_text_token_id = 128009
-         self.end_human_token_id = 128260
-         # self.padding_token_id = 128263  # Not needed for single-sequence generation
-         self.start_audio_token_id = 128257
-         self.end_audio_token_id = 128258
-         self.audio_code_offset = 128266
-
-         # Define sampling rate
-         self.sampling_rate = 24000
-
-         try:
-             # Load Orpheus LLM and tokenizer
-             print(f"Loading Orpheus tokenizer from: {orpheus_model_path}")
-             self.tokenizer = AutoTokenizer.from_pretrained(orpheus_model_path)
-             print(f"Loading Orpheus model from: {orpheus_model_path}")
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 orpheus_model_path,
-                 torch_dtype=torch.bfloat16  # Use bfloat16 as in the training script
-             )
-             self.model.to(self.device)
-             self.model.eval()  # Set model to evaluation mode
-             print("Orpheus model and tokenizer loaded successfully.")
-
-             # Load SNAC vocoder
-             print(f"Loading SNAC model from: {snac_model_name}")
-             self.snac_model = SNAC.from_pretrained(snac_model_name)
-             self.snac_model.to(self.device)  # Move SNAC to the same device
-             self.snac_model.eval()  # Set model to evaluation mode
-             print("SNAC model loaded successfully.")
-
-         except Exception as e:
-             print(f"Error during model loading: {e}")
-             raise RuntimeError("Failed to load models.") from e
-
-     def __call__(self, data: dict) -> bytes:
          """
-         Handles incoming API requests for TTS inference.
-         Expects data['inputs'] (text) and optionally data['parameters'].
          """
-         try:
-             # --- Get inputs & parameters ---
-             text = data.pop("inputs", None)
-             if text is None:
-                 raise ValueError("Missing 'inputs' key in request data")
-
-             parameters = data.pop("parameters", {})
-             # Default voice if not provided
-             voice = parameters.get("voice", "Eniola")
-             # Default generation parameters (merged with provided ones)
-             gen_params = {
-                 "max_new_tokens": 1200,
-                 "do_sample": True,
-                 "temperature": 0.6,
-                 "top_p": 0.95,
-                 "repetition_penalty": 1.1,
-                 "num_return_sequences": 1,
-                 "eos_token_id": self.end_audio_token_id,
-                 **parameters  # Overwrite defaults with user params
-             }
-             # Remove non-generate params if they were passed
-             gen_params.pop("voice", None)
-
-             print(f"Received request: text='{text[:50]}...', voice='{voice}', params={gen_params}")
-
-             # --- Preprocess text ---
-             prompt = f"{voice}: {text}"
-             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
-
-             # Add special tokens: SOH + input tokens + EOT + EOH
-             start_token_tensor = torch.tensor([[self.start_human_token_id]], dtype=torch.int64)
-             end_tokens_tensor = torch.tensor([[self.end_text_token_id, self.end_human_token_id]], dtype=torch.int64)
-             processed_input_ids = torch.cat([start_token_tensor, input_ids, end_tokens_tensor], dim=1).to(self.device)
-
-             # Create attention mask (all ones for a single, unpadded sequence)
-             attention_mask = torch.ones_like(processed_input_ids).to(self.device)
-
-             print(f"Processed input shape: {processed_input_ids.shape}")
-
-             # --- Generate audio codes (LLM inference) ---
-             with torch.no_grad():
-                 generated_ids = self.model.generate(
-                     input_ids=processed_input_ids,
-                     attention_mask=attention_mask,
-                     **gen_params
-                 )
-             print(f"Generated IDs shape: {generated_ids.shape}")
-
-             # --- Process generated tokens (extract audio codes) ---
-             # Find the last Start of Audio token
-             soa_indices = (generated_ids[0] == self.start_audio_token_id).nonzero(as_tuple=True)[0]
-             if len(soa_indices) == 0:
-                 print("Warning: Start of Audio token (128257) not found in generated sequence!")
-                 # Fall back to processing all newly generated tokens after the prompt
-                 start_idx = processed_input_ids.shape[1]
-             else:
-                 start_idx = soa_indices[-1].item() + 1  # Start after the last SOA token
-
-             # Extract potential audio codes (after the last SOA or after the input)
-             cropped_tokens = generated_ids[0, start_idx:]
-
-             # Remove End of Audio tokens
-             audio_codes_raw = cropped_tokens[cropped_tokens != self.end_audio_token_id]
-             print(f"Extracted raw audio codes count: {len(audio_codes_raw)}")
-
-             if len(audio_codes_raw) == 0:
-                 raise ValueError("No audio codes generated or extracted after processing.")
-
-             # --- Prepare codes for the SNAC vocoder ---
-             # Adjust token values
-             adjusted_codes = [t.item() - self.audio_code_offset for t in audio_codes_raw]
-
-             # Trim to a multiple of 7
-             num_codes = len(adjusted_codes)
-             valid_length = (num_codes // 7) * 7
-             if valid_length == 0:
-                 raise ValueError(f"Not enough audio codes ({num_codes}) to form a multiple of 7 after processing.")
-             trimmed_codes = adjusted_codes[:valid_length]
-             print(f"Trimmed adjusted audio codes count: {len(trimmed_codes)}")
-
-             # --- Redistribute codes ---
-             # Use static method, ensuring tensors end up on the correct device
-             snac_input_codes = redistribute_codes_static(trimmed_codes)
-             snac_input_codes = [layer.to(self.device) for layer in snac_input_codes]
-
-             # --- Decode audio (SNAC inference) ---
-             print("Decoding audio with SNAC...")
-             with torch.no_grad():
-                 audio_hat = self.snac_model.decode(snac_input_codes)
-             print(f"Decoded audio tensor shape: {audio_hat.shape}")  # Should be [1, 1, num_samples]
-
-             # --- Postprocess audio ---
-             # Move to CPU, remove batch/channel dims, convert to numpy
-             audio_waveform = audio_hat.detach().squeeze().cpu().numpy()
-
-             # --- Convert to WAV bytes ---
-             buffer = io.BytesIO()
-             sf.write(buffer, audio_waveform, self.sampling_rate, format='WAV')
-             buffer.seek(0)
-             wav_bytes = buffer.read()
-
-             print(f"Generated {len(wav_bytes)} bytes of WAV audio.")
-             return wav_bytes
-
-         except Exception as e:
-             print(f"Error during inference call: {e}")
-             # Re-raise for the endpoint framework
-             raise RuntimeError(f"Inference failed: {e}")
 
+ import os
  import torch
  import numpy as np
  from transformers import AutoModelForCausalLM, AutoTokenizer
+ from snac import SNAC

+ class EndpointHandler:
+     def __init__(self, path=""):
+         # Load the Orpheus model and tokenizer
+         self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
+         self.model = AutoModelForCausalLM.from_pretrained(
+             self.model_name,
+             torch_dtype=torch.bfloat16
+         )
+
+         # Move model to GPU if available
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model.to(self.device)
+
+         # Load tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+         # Load SNAC model for audio decoding
+         self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+         self.snac_model.to(self.device)
+
+         # Special tokens
+         self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
+         self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
+         self.padding_token = 128263
+         self.start_audio_token = 128257  # Start of Audio token
+         self.end_audio_token = 128258  # End of Audio token
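+         # Note: generated audio codes occupy token ids 128266 and up;
+         # postprocess() subtracts that offset to recover raw SNAC codes.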
+
+         print(f"Model loaded on {self.device}")
+
+     def preprocess(self, data):
+         """
+         Preprocess input data before inference
+         """
+         inputs = data.pop("inputs", data)
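+         # Expects a payload like {"inputs": {"text": "...", "voice": "...", ...}};
+         # a bare string under "inputs" would break the .get() lookups below.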
+
+         # Extract parameters from request
+         text = inputs.get("text", "")
+         voice = inputs.get("voice", "tara")
+         temperature = float(inputs.get("temperature", 0.6))
+         top_p = float(inputs.get("top_p", 0.95))
+         max_new_tokens = int(inputs.get("max_new_tokens", 1200))
+         repetition_penalty = float(inputs.get("repetition_penalty", 1.1))
+
+         # Format prompt with voice
+         prompt = f"{voice}: {text}"
+
+         # Tokenize
+         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+
+         # Add special tokens
+         modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)
+
+         # No need for padding as we're processing a single sequence
+         input_ids = modified_input_ids.to(self.device)
+         attention_mask = torch.ones_like(input_ids)
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "temperature": temperature,
+             "top_p": top_p,
+             "max_new_tokens": max_new_tokens,
+             "repetition_penalty": repetition_penalty
+         }
+
+     def inference(self, inputs):
+         """
+         Run model inference on the preprocessed inputs
+         """
+         # Extract parameters
+         input_ids = inputs["input_ids"]
+         attention_mask = inputs["attention_mask"]
+         temperature = inputs["temperature"]
+         top_p = inputs["top_p"]
+         max_new_tokens = inputs["max_new_tokens"]
+         repetition_penalty = inputs["repetition_penalty"]
+
+         # Generate output tokens
+         with torch.no_grad():
+             generated_ids = self.model.generate(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 num_return_sequences=1,
+                 eos_token_id=self.end_audio_token,
+             )
+
+         return generated_ids
+
+     def postprocess(self, generated_ids):
+         """
+         Process generated tokens into audio
+         """
+         # Find Start of Audio token
+         token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)
+
+         if len(token_indices[1]) > 0:
+             last_occurrence_idx = token_indices[1][-1].item()
+             cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
+         else:
+             cropped_tensor = generated_ids
+
+         # Remove End of Audio tokens
+         processed_rows = []
+         for row in cropped_tensor:
+             masked_row = row[row != self.end_audio_token]
+             processed_rows.append(masked_row)
+
+         # Prepare audio codes
+         code_lists = []
+         for row in processed_rows:
+             row_length = row.size(0)
+             # Ensure length is a multiple of 7 for SNAC
+             new_length = (row_length // 7) * 7
+             trimmed_row = row[:new_length]
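+             # Token id 128266 marks the first audio code; subtracting it maps ids to raw SNAC codes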
+             trimmed_row = [t.item() - 128266 for t in trimmed_row]  # Adjust token values
+             code_lists.append(trimmed_row)
+
+         # Generate audio from codes
+         audio_samples = []
+         for code_list in code_lists:
+             audio = self.redistribute_codes(code_list)
+             audio_samples.append(audio)
+
+         # Return first (and only) audio sample
+         audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
+
+         # Convert to base64 for transmission
+         import base64
+         import io
+         import wave
+
+         # Convert float32 array to int16 for WAV format
+         audio_int16 = (audio_sample * 32767).astype(np.int16)
+
+         # Create WAV in memory
+         with io.BytesIO() as wav_io:
+             with wave.open(wav_io, 'wb') as wav_file:
+                 wav_file.setnchannels(1)  # Mono
+                 wav_file.setsampwidth(2)  # 16-bit
+                 wav_file.setframerate(24000)  # 24 kHz
+                 wav_file.writeframes(audio_int16.tobytes())
+             wav_data = wav_io.getvalue()
+
+         # Encode as base64
+         audio_b64 = base64.b64encode(wav_data).decode('utf-8')
+
+         return {
+             "audio_b64": audio_b64,
+             "sample_rate": 24000
+         }
+
+     def redistribute_codes(self, code_list):
+         """
+         Reorganize tokens for SNAC decoding
+         """
+         layer_1 = []  # Coarsest layer
+         layer_2 = []  # Intermediate layer
+         layer_3 = []  # Finest layer
+
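+         # Each 7-token group interleaves the three SNAC levels as
+         # [L1, L2, L3, L1, L2, L3, L3]; position k in the group carries an
+         # offset of k * 4096 that must be subtracted to recover the raw code.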
+         num_groups = len(code_list) // 7
+         for i in range(num_groups):
+             idx = 7 * i
              layer_1.append(code_list[idx])
              layer_2.append(code_list[idx + 1] - 4096)
              layer_3.append(code_list[idx + 2] - (2 * 4096))
              layer_1.append(code_list[idx + 3] - (3 * 4096))
              layer_2.append(code_list[idx + 4] - (4 * 4096))
              layer_3.append(code_list[idx + 5] - (5 * 4096))
              layer_3.append(code_list[idx + 6] - (6 * 4096))
+
+         codes = [
+             torch.tensor(layer_1).unsqueeze(0).to(self.device),
+             torch.tensor(layer_2).unsqueeze(0).to(self.device),
+             torch.tensor(layer_3).unsqueeze(0).to(self.device)
+         ]
+
+         # Decode audio
+         audio_hat = self.snac_model.decode(codes)
+         return audio_hat
+
+     def __call__(self, data):
          """
+         Main entry point for the handler
          """
+         preprocessed_inputs = self.preprocess(data)
+         model_outputs = self.inference(preprocessed_inputs)
+         response = self.postprocess(model_outputs)
+         return response
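
For reference, a minimal client-side sketch of how the new JSON response might be consumed. This is an illustration only: the endpoint URL and token are placeholders, and it assumes the handler is deployed behind a standard Hugging Face Inference Endpoint.

import base64

import requests  # third-party HTTP client, assumed available on the client side

# Placeholders: substitute your deployed endpoint URL and access token.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

payload = {
    "inputs": {
        "text": "Hello from Orpheus!",
        "voice": "tara",          # same default the handler's preprocess() uses
        "temperature": 0.6,
        "top_p": 0.95,
        "max_new_tokens": 1200,
        "repetition_penalty": 1.1,
    }
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
    timeout=300,
)
response.raise_for_status()
body = response.json()  # {"audio_b64": "...", "sample_rate": 24000}

# Decode the base64-encoded WAV produced by postprocess() and save it.
with open("output.wav", "wb") as f:
    f.write(base64.b64decode(body["audio_b64"]))

Returning base64-encoded WAV inside JSON, rather than the raw bytes the removed handler returned, keeps the response compatible with JSON-only inference clients at the cost of roughly 33% payload overhead.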