voidful committed
Commit 6cb34fb · verified · 1 Parent(s): ea21fb3

Update processing_gemma3_omni.py

Files changed (1)
  1. processing_gemma3_omni.py +214 -265
processing_gemma3_omni.py CHANGED
@@ -1,9 +1,9 @@
1
  import re
2
- from typing import List, Optional, Union, Dict, Any, Tuple
3
 
4
  import math
5
  import numpy as np
6
- import scipy.signal # Keep for resample_poly
7
  import torch
8
  from torch.nn.utils.rnn import pad_sequence
9
 
@@ -15,28 +15,31 @@ AudioInput = Union[np.ndarray, List[float], Tuple[np.ndarray, int]]
15
 
16
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
17
  from transformers.feature_extraction_utils import BatchFeature
18
- # from transformers.image_utils import make_nested_list_of_images # Not used in this file
19
- from transformers.processing_utils import ProcessorMixin, \
20
- ProcessingKwargs # Removed ImagesKwargs, Unpack for simplicity as they are not fully defined here
21
  from transformers.utils import TensorType, to_py_obj, logging
 
 
22
 
23
- # Constants
 
24
  DEFAULT_SAMPLING_RATE = 16000
25
  DEFAULT_N_FFT = 512
26
- DEFAULT_WIN_LENGTH = 400
27
- DEFAULT_HOP_LENGTH = 160
28
  DEFAULT_N_MELS = 80
29
- DEFAULT_COMPRESSION_RATE = 4 # As in FeatureExtractor init
30
- DEFAULT_QFORMER_RATE = 2 # As in FeatureExtractor init
31
- DEFAULT_FEAT_STRIDE = 4 # As in FeatureExtractor init, affects 'frames' if used
32
- IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>" # From user
33
- AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>" # From user
34
- DEFAULT_MAX_LENGTH = 16384 # From user
35
  LOG_MEL_CLIP_EPSILON = 1e-5
36
 
37
  logger = logging.get_logger(__name__)
38
 
39
 
 
 
40
  def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
41
  fmax: Optional[float] = None) -> np.ndarray:
42
  """Create Mel filterbank for audio processing."""
@@ -45,79 +48,66 @@ def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: flo
45
  if fmin >= fmax:
46
  raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
47
 
48
- def hz_to_mel(f: float) -> float: # Using HTK formula (as in librosa default)
49
  return 2595.0 * math.log10(1 + f / 700.0)
50
 
51
  def mel_to_hz(mel: float) -> float:
52
- return 700.0 * (10 ** (mel / 2595.0) - 1)
53
 
54
  mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
55
  freq_points = mel_to_hz(mel_points)
56
-
57
- # Ensure freq_points are within the Nyquist limit
58
  freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
59
-
60
- # fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate) # Frequencies for each FFT bin center
61
- # bins = np.searchsorted(fftfreqs, freq_points) # More robust way to find bins
62
- bins = np.floor((n_fft / 2.0) * freq_points / (sampling_rate / 2.0)).astype(int) # Simplified from librosa
63
- bins = np.clip(bins, 0, n_fft // 2) # Max index for rfft output (n_fft//2 + 1 bins, so max index is n_fft//2)
64
 
65
  filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
66
- for m in range(n_mels): # Iterate 0 to n_mels-1 for m-th filter
67
  left, center, right = bins[m], bins[m + 1], bins[m + 2]
68
-
 
69
  if center > left:
70
- filterbank[m, left:center + 1] = (np.arange(left, center + 1) - left) / (
71
- center - left) # Inclusive center for peak
72
  if right > center:
73
- # The peak is at 'center'. So the downward slope starts from 'center'.
74
- filterbank[m, center:right + 1] = (right - np.arange(center, right + 1)) / (
75
- right - center) # Inclusive center
76
- # Ensure the peak of the triangle is 1, and handle cases where left=center or center=right
77
- # This simplified version might need librosa.filters.mel for full robustness if edge cases are hit.
78
- # For now, ensuring distinct points for slopes:
79
- if center > left and center < right: # Standard triangle
80
- # Ramp up
81
- idxs_up = np.arange(left, center + 1) # include center
82
- filterbank[m, idxs_up] = (idxs_up - left) / (center - left)
83
- # Ramp down
84
- idxs_down = np.arange(center, right + 1) # include center
85
- filterbank[m, idxs_down] = (right - idxs_down) / (right - center)
86
- elif center > left: # only ramp up (right part is flat or non-existent)
87
- idxs_up = np.arange(left, center + 1)
88
- filterbank[m, idxs_up] = (idxs_up - left) / (center - left)
89
- elif center < right: # only ramp down (left part is flat or non-existent)
90
- idxs_down = np.arange(center, right + 1)
91
- filterbank[m, idxs_down] = (right - idxs_down) / (right - center)
92
 
93
- return filterbank
94
 
 
95
 
 
 
96
  class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
97
- model_input_names = ["audio_values", "audio_attention_mask"] # What this extractor produces for the model
98
 
99
  def __init__(
100
  self,
101
  compression_rate: int = DEFAULT_COMPRESSION_RATE,
102
  qformer_rate: int = DEFAULT_QFORMER_RATE,
103
  feat_stride: int = DEFAULT_FEAT_STRIDE,
104
- sampling_rate: int = DEFAULT_SAMPLING_RATE, # Target sampling rate
105
  n_fft: int = DEFAULT_N_FFT,
106
- win_length: Optional[int] = None, # Default to n_fft if None
107
- hop_length: Optional[int] = None, # Default to win_length // 4 if None
108
  n_mels: int = DEFAULT_N_MELS,
109
  f_min: float = 0.0,
110
  f_max: Optional[float] = None,
111
- padding_value: float = 0.0, # For pad_sequence of mels
112
  **kwargs
113
  ):
114
- # feature_size is the number of features per frame (n_mels)
115
  super().__init__(feature_size=n_mels, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
116
 
117
  self.compression_rate = compression_rate
118
  self.qformer_rate = qformer_rate
119
- self.feat_stride = feat_stride # Used for 'frames' calculation, purpose might be downstream
120
-
121
  self.n_fft = n_fft
122
  self.win_length = win_length if win_length is not None else n_fft
123
  self.hop_length = hop_length if hop_length is not None else self.win_length // 4
@@ -130,26 +120,24 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
130
  f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
131
  f"For FFT computation, the window will effectively be truncated or the signal zero-padded to n_fft length."
132
  )
133
- self.window = scipy.signal.get_window("hann", self.win_length).astype(np.float32) # Using Hann window
134
  self.mel_filterbank = create_mel_filterbank(
135
  self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
136
- ).T # Transpose to (n_fft // 2 + 1, n_mels)
137
 
138
  def __call__(
139
  self,
140
  audios: Union[AudioInput, List[AudioInput]],
141
- sampling_rate: Optional[int] = None, # SR of the input audios if they are raw arrays
142
  return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
143
  ) -> BatchFeature:
144
-
145
  if not isinstance(audios, list):
146
  audios = [audios]
147
 
148
  processed_mel_spectrograms: List[torch.Tensor] = []
149
  actual_mel_lengths: List[int] = []
150
-
151
- # Optional downstream calculation values, kept if needed by other parts of Gemma3OmniProcessor
152
- downstream_sizes_for_token_calc: List[torch.Tensor] = []
153
  downstream_frames_scaled_for_token_calc: List[int] = []
154
 
155
  for audio_input_item in audios:
@@ -166,73 +154,58 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
166
  "sampling_rate must be provided if audio inputs are raw numpy arrays or lists."
167
  )
168
  source_sr = sampling_rate
169
- else: # Add more robust loading here if using file paths or bytes
170
  raise TypeError(
171
  f"Unsupported audio input type: {type(audio_input_item)}. "
172
  "This extractor expects np.ndarray, list of floats, or Tuple[np.ndarray, int indicating SR]."
173
  )
174
 
175
  processed_wav = self._preprocess_audio(current_wav_array, source_sr)
176
- # mel_spectrogram has shape (frame_count, n_mels)
177
- mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
178
-
179
- feature_tensor = torch.from_numpy(mel_spectrogram) # Already float32 from _compute_log_mel_spectrogram
180
  processed_mel_spectrograms.append(feature_tensor)
181
- actual_mel_lengths.append(feature_tensor.shape[0]) # frame_count for this item
182
 
183
- # These calculations seem related to determining the number of special tokens in the prompt
184
  downstream_sizes_for_token_calc.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
185
  downstream_frames_scaled_for_token_calc.append(feature_tensor.shape[0] * self.feat_stride)
186
 
187
- # Pad the mel spectrograms
188
- # audio_values will have shape (Batch, Max_frame_count, n_mels)
189
  audio_values = pad_sequence(processed_mel_spectrograms, batch_first=True, padding_value=self.padding_value)
190
-
191
- # Create attention mask corresponding to audio_values
192
- max_mel_len = audio_values.shape[1] # Max_frame_count across the batch
193
  lengths_tensor = torch.tensor(actual_mel_lengths, dtype=torch.long)
194
-
195
- audio_attention_mask = torch.arange(max_mel_len).unsqueeze(0).expand(len(audios),
196
- -1) < lengths_tensor.unsqueeze(1)
197
-
198
  output_data = {
199
- "audio_values": audio_values, # (B, Max_T_mel, N_Mels) -> Input to Conformer
200
- "audio_attention_mask": audio_attention_mask # (B, Max_T_mel) -> Mask for Conformer input
201
  }
202
-
203
- # Include these if they are used by Gemma3OmniProcessor for prompt construction
204
- if downstream_sizes_for_token_calc: # Renamed to clarify purpose
205
- output_data["audio_token_calc_sizes"] = torch.stack(downstream_sizes_for_token_calc)
206
- # if downstream_frames_scaled_for_token_calc: # Example if needed
207
- # output_data["audio_token_calc_frames_scaled"] = torch.tensor(downstream_frames_scaled_for_token_calc)
208
-
209
  return BatchFeature(data=output_data, tensor_type=return_tensors)
210
 
211
  def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
212
  if wav.dtype not in [np.float32, np.float64]:
213
- # Assuming int audio needs normalization to [-1, 1]
214
  if np.issubdtype(wav.dtype, np.integer):
215
  max_val = np.iinfo(wav.dtype).max
216
  wav = wav.astype(np.float32) / max_val
217
- else: # Other non-float types
218
  wav = wav.astype(np.float32)
219
-
220
  if wav.ndim > 1:
221
- wav = wav.mean(axis=0) # Convert to mono
222
-
223
  if source_sr != self.sampling_rate:
224
- # Using scipy.signal.resample_poly for potentially higher quality resampling
225
- # It requires integer up/down factors.
226
  gcd = math.gcd(self.sampling_rate, source_sr)
227
  up_factor = self.sampling_rate // gcd
228
  down_factor = source_sr // gcd
229
- if up_factor != down_factor: # Only resample if ratio is not 1
230
- logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
231
- wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
232
-
233
- # Peak normalization
234
  norm_factor = np.abs(wav).max()
235
- if norm_factor > 1e-9: # Avoid division by zero/small numbers for silent/near-silent audio
236
  wav = wav / norm_factor
237
  return wav
238
 
@@ -241,12 +214,9 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
241
  padding = self.win_length - len(wav)
242
  wav = np.pad(wav, (0, padding), mode='constant', constant_values=0.0)
243
 
244
- # Using librosa-like STFT parameters where n_fft, hop_length, win_length are explicit
245
- # Manual framing and windowing:
246
  num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
247
  if num_frames <= 0:
248
- logger.warning(
249
- f"Audio of length {len(wav)} is too short to produce frames with win_length {self.win_length} and hop_length {self.hop_length}. Returning empty mel spectrogram.")
250
  return np.zeros((0, self.n_mels), dtype=np.float32)
251
 
252
  frames = np.lib.stride_tricks.as_strided(
@@ -255,153 +225,144 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
255
  strides=(wav.strides[0] * self.hop_length, wav.strides[0]),
256
  writeable=False
257
  )
258
-
259
  windowed_frames = frames * self.window
260
-
261
- stft_matrix = np.fft.rfft(windowed_frames, n=self.n_fft, axis=-1) # Shape (num_frames, n_fft // 2 + 1)
262
- powers = np.abs(stft_matrix) ** 2
263
-
264
- mel_spectrogram = np.dot(powers, self.mel_filterbank) # Shape (num_frames, n_mels)
265
-
266
  mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None)
267
  log_mel_spectrogram = np.log(mel_spectrogram)
268
-
269
  return log_mel_spectrogram.astype(np.float32)
270
 
271
  def _calculate_embed_length(self, frame_count: int) -> int:
272
- # This calculation is likely for determining the number of special tokens
273
- # to insert in the text prompt, based on the audio length.
274
  compressed = math.ceil(frame_count / self.compression_rate)
275
  return math.ceil(compressed / self.qformer_rate)
276
 
277
-
278
- # --- Gemma3ProcessorKwargs and Gemma3ImagesKwargs would be defined here if needed ---
279
- # For this fix, focusing on Gemma3AudioFeatureExtractor and Gemma3OmniProcessor interactions
280
- class Gemma3DummyProcessorKwargs(ProcessingKwargs, total=False): # Dummy for testing
281
  images_kwargs: Dict[str, Any]
282
  audio_kwargs: Dict[str, Any]
283
- text_kwargs: Dict[str, Any] # Added for completeness
284
  _defaults = {
285
  "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
286
  "images_kwargs": {},
287
  "audio_kwargs": {}
288
  }
289
 
290
-
291
  class Gemma3OmniProcessor(ProcessorMixin):
292
  attributes = ["image_processor", "audio_processor", "tokenizer"]
293
- # Define image_processor_class and tokenizer_class if they are standard,
294
- # or handle their instantiation/passing carefully.
295
- # For custom audio_processor, we reference the class directly.
296
- audio_processor_class = Gemma3AudioFeatureExtractor
297
 
298
- # image_processor_class = "AutoImageProcessor" # Example
299
- # tokenizer_class = "AutoTokenizer" # Example
300
 
301
  def __init__(
302
  self,
303
- tokenizer, # Tokenizer is essential
304
  audio_processor: Optional[Union[Gemma3AudioFeatureExtractor, Dict]] = None,
305
- image_processor=None, # Define further if used
306
  chat_template=None,
307
- image_seq_length: int = 256, # Default from user code
308
- # Parameters for calculating number of audio soft tokens in text prompt
309
- audio_prompt_compression_rate: int = 8,
310
  audio_prompt_qformer_rate: int = 1,
311
  audio_prompt_feat_stride: int = 1,
312
- audio_placeholder_token: str = "<|audio_placeholder|>", # Placeholder in user text
313
- audio_soft_token_str: str = "<audio_soft_token>", # The actual soft token string
314
  **kwargs
315
  ):
 
316
  if audio_processor is None:
317
- logger.info("Initializing Gemma3AudioFeatureExtractor with default parameters.")
318
  audio_processor = Gemma3AudioFeatureExtractor()
319
  elif isinstance(audio_processor, Dict):
320
  audio_processor = Gemma3AudioFeatureExtractor(**audio_processor)
321
- elif not isinstance(audio_processor, Gemma3AudioFeatureExtractor):
322
- raise TypeError(
323
- f"audio_processor must be an instance of Gemma3AudioFeatureExtractor or a config dict, got {type(audio_processor)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
- # Ensure tokenizer is provided and instantiated
326
- if tokenizer is None: # This check might be redundant if from_pretrained handles it
327
- raise ValueError("A tokenizer must be provided.")
328
- # if isinstance(tokenizer, str): # Basic loading, usually from_pretrained handles complex cases
329
- # tokenizer = AutoTokenizer.from_pretrained(tokenizer)
330
 
331
  super().__init__(
332
  image_processor=image_processor,
333
  audio_processor=audio_processor,
334
  tokenizer=tokenizer,
335
  chat_template=chat_template,
336
- **kwargs
337
  )
338
-
339
  self.image_seq_length = image_seq_length
340
- # Robustly get special image tokens from tokenizer
341
- self.image_token_id = getattr(tokenizer, "image_token_id",
342
- tokenizer.unk_token_id if hasattr(tokenizer, "unk_token_id") else None)
343
- self.boi_token = getattr(tokenizer, "boi_token", "<|image|>") # Using <|image|> as a more common BOI
344
- self.image_token = getattr(tokenizer, "image_token", "<|image|>")
345
- self.eoi_token = getattr(tokenizer, "eoi_token", "") # End of image, can be empty
346
 
347
  self.audio_placeholder_token = audio_placeholder_token
348
  self.audio_soft_token_str = audio_soft_token_str
349
-
350
- # Get ID for the audio soft token string; it must exist in the tokenizer
351
  self.audio_soft_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_soft_token_str)
352
- if self.audio_soft_token_id == self.tokenizer.unk_token_id:
353
- logger.warning(
354
- f"The audio soft token string '{self.audio_soft_token_str}' maps to UNK token. "
355
  "Ensure it is added to the tokenizer's vocabulary as a special token."
356
  )
357
- # User's original expected ID, for reference or potential validation
358
- # self.expected_audio_token_id = 262143
359
- # if self.audio_soft_token_id != self.expected_audio_token_id:
360
- # logger.warning(f"Assigned ID {self.audio_soft_token_id} for '{self.audio_soft_token_str}' does not match expected ID {self.expected_audio_token_id}.")
361
 
362
  self.full_image_sequence_str = f"\n\n{self.boi_token}{''.join([self.image_token] * self.image_seq_length)}{self.eoi_token}\n\n"
363
 
364
- # Store rates for calculating number of audio soft tokens for the prompt
365
  self.audio_prompt_compression_rate = audio_prompt_compression_rate
366
  self.audio_prompt_qformer_rate = audio_prompt_qformer_rate
367
  self.audio_prompt_feat_stride = audio_prompt_feat_stride
368
 
369
- def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs_passed_to_call):
370
- # Use ModelProcessorKwargs (e.g. Gemma3DummyProcessorKwargs) to get default structures
371
- # This method was complex in user code, simplifying slightly for clarity
372
  final_kwargs = {}
373
  # Initialize with _defaults from the Kwargs class
374
- for modality_key, default_modality_kwargs in ModelProcessorKwargs._defaults.items():
 
 
375
  final_kwargs[modality_key] = default_modality_kwargs.copy()
376
 
377
  # Override with tokenizer's init_kwargs if they exist for a given key
378
  for modality_key, modality_dict in final_kwargs.items():
379
- for key in list(modality_dict.keys()): # Iterate over copy of keys
380
  if key in tokenizer_init_kwargs:
381
  modality_dict[key] = tokenizer_init_kwargs[key]
382
-
383
- # Override with kwargs passed directly to __call__ (e.g. kwargs['text_kwargs'])
384
- for modality_key, modality_dict_from_call in kwargs_passed_to_call.items():
385
- if modality_key in final_kwargs and isinstance(modality_dict_from_call, dict):
386
- final_kwargs[modality_key].update(modality_dict_from_call)
387
- elif modality_key not in final_kwargs and isinstance(modality_dict_from_call,
388
- dict): # For kwargs not in _defaults
389
- final_kwargs[modality_key] = modality_dict_from_call.copy()
390
-
391
- # Specific handling for text_kwargs as in user code
392
- if "text_kwargs" in final_kwargs:
393
- final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation",
394
- False) # Keep user's if provided
395
- final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length",
396
- DEFAULT_MAX_LENGTH)
397
- else: # Ensure text_kwargs exists
398
- final_kwargs["text_kwargs"] = {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH}
399
-
400
  return final_kwargs
401
 
402
  def _compute_audio_prompt_token_count(self, actual_mel_frames_count: int) -> int:
403
- """Calculates how many <audio_soft_token> to insert in the text prompt."""
404
- # Uses parameters specific to this processor for prompt engineering
405
  scaled_frames = actual_mel_frames_count * self.audio_prompt_feat_stride
406
  compressed_once = math.ceil(scaled_frames / self.audio_prompt_compression_rate)
407
  compressed_twice = math.ceil(compressed_once / self.audio_prompt_qformer_rate)
@@ -410,118 +371,109 @@ class Gemma3OmniProcessor(ProcessorMixin):
410
  def __call__(
411
  self,
412
  text: Union[str, List[str]] = None,
413
- images: Optional[Any] = None, # Type depends on image_processor
414
  audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
415
- sampling_rate: Optional[int] = None, # For raw audio arrays passed to audio_processor
416
- return_tensors: Optional[Union[str, TensorType]] = None, # Default behavior based on HF
417
- **kwargs: Any # Using Any for Unpack[Gemma3ProcessorKwargs] as it's not fully defined
418
  ) -> BatchFeature:
419
-
420
  if text is None and images is None and audios is None:
421
  raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
422
 
423
- # Determine final return_tensors type (passed explicitly, or from text_kwargs, or default)
424
- # The _merge_kwargs will get text_kwargs, but return_tensors can be a top-level arg.
425
- _text_kwargs_from_call = kwargs.get("text_kwargs", {})
426
- _rt_from_text_kwargs = _text_kwargs_from_call.get("return_tensors")
427
-
428
- final_rt = TensorType.PYTORCH # Default
429
- if return_tensors is not None:
430
- final_rt = return_tensors
431
- elif _rt_from_text_kwargs is not None:
432
- final_rt = _rt_from_text_kwargs
433
-
434
- # Get all kwargs merged (text_kwargs, images_kwargs, audio_kwargs)
435
- # Using Gemma3DummyProcessorKwargs as the source of _defaults structure
436
- merged_kwargs = self._merge_kwargs(
437
- Gemma3DummyProcessorKwargs,
438
  self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},
439
- **kwargs
440
  )
 
 
 
 
 
 
441
 
442
- # Ensure text is a list of strings
443
  if text is None:
444
  num_samples = 0
445
  if images is not None:
446
- # Simplified: make_nested_list_of_images would be here for HF standard image proc.
447
- num_samples = len(images) if isinstance(images, list) and not isinstance(images[0], (float, int)) else 1
448
  elif audios is not None:
449
- num_samples = len(audios) if isinstance(audios, list) else 1
450
- text = [""] * num_samples if num_samples > 0 else [""] # Dummy text if only modality
451
-
 
452
  if isinstance(text, str):
453
  text = [text]
454
  if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
455
  raise ValueError("Input `text` must be a string or a list of strings.")
456
 
457
- # --- Image Processing (Simplified Placeholder) ---
458
- image_features = {}
459
  if images is not None and self.image_processor is not None:
460
  logger.info("Processing images...")
461
- # Actual image processing call would be here, e.g.:
462
- # image_features = self.image_processor(images, return_tensors=None, **merged_kwargs.get("images_kwargs", {}))
463
- # And text would be modified to include image tokens like self.full_image_sequence_str
464
- # For now, just a pass-through to show where it fits.
465
- pass # Replace with actual image processing and text modification logic
466
-
467
- # --- Audio Processing ---
468
- audio_features = {}
469
  if audios is not None and self.audio_processor is not None:
470
  logger.info("Processing audio...")
471
- audio_call_kwargs = merged_kwargs.get("audio_kwargs", {})
472
- if sampling_rate: # Pass sampling_rate if provided for raw arrays
473
- audio_call_kwargs["sampling_rate"] = sampling_rate
 
 
 
 
474
 
475
- # Get dict of numpy arrays/lists first from feature_extractor
476
- audio_features_np = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
477
- audio_features = audio_features_np # Store the dict
478
-
479
- # Modify text to include audio soft tokens
480
  new_text_with_audio = []
481
- # audio_attention_mask from feature extractor is (B, Max_T_mel)
482
- audio_sample_mel_lengths = audio_features_np["audio_attention_mask"].sum(
483
- axis=-1) # Get actual mel frames per sample
484
 
485
  for i, prompt in enumerate(text):
486
- num_soft_tokens = self._compute_audio_prompt_token_count(int(audio_sample_mel_lengths[i]))
487
  audio_token_sequence_str = self.audio_soft_token_str * num_soft_tokens
488
-
489
  if self.audio_placeholder_token in prompt:
490
  prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str, 1)
491
- else: # If no placeholder, append the audio tokens
492
- prompt += audio_token_sequence_str
493
  new_text_with_audio.append(prompt)
494
  text = new_text_with_audio
495
-
496
- # --- Text Tokenization ---
497
  logger.info("Tokenizing text...")
498
- text_call_kwargs = merged_kwargs.get("text_kwargs", {})
499
- # Ensure tokenizer gets lists/np arrays, BatchFeature will handle final tensor conversion
500
- text_features = self.tokenizer(text, return_tensors=None, **text_call_kwargs)
501
 
502
- # --- Create token_type_ids ---
503
- input_ids_list = text_features["input_ids"] # Should be list of lists of token IDs
504
- if not isinstance(input_ids_list[0], list): # Handle if tokenizer returns single list for single text
505
- input_ids_list = [input_ids_list]
 
 
506
 
507
  token_type_ids_list = []
508
  for ids_sample in input_ids_list:
509
- types = [0] * len(ids_sample) # 0 for text
510
  for j, token_id in enumerate(ids_sample):
511
  if self.image_token_id is not None and token_id == self.image_token_id:
512
- types[j] = 1 # 1 for image
513
- elif token_id == self.audio_soft_token_id: # Compare with the ID of the soft token string
514
- types[j] = 2 # 2 for audio
515
  token_type_ids_list.append(types)
516
- text_features["token_type_ids"] = token_type_ids_list
517
-
518
- # Combine all features
519
- combined_features = {**text_features}
520
- if image_features: # image_features is already a dict
521
- combined_features.update(image_features)
522
- if audio_features: # audio_features is already a dict from audio_processor
523
- combined_features.update(audio_features)
524
-
525
  return BatchFeature(data=combined_features, tensor_type=final_rt)
526
 
527
  def batch_decode(self, *args, **kwargs):
@@ -532,14 +484,11 @@ class Gemma3OmniProcessor(ProcessorMixin):
532
 
533
  @property
534
  def model_input_names(self) -> List[str]:
535
- """
536
- Defines the expected inputs for the model. Combines tokenizer, image_processor, and audio_processor inputs.
537
- """
538
  input_names = set(self.tokenizer.model_input_names + ["token_type_ids"])
539
  if self.image_processor is not None:
540
  input_names.update(self.image_processor.model_input_names)
541
  if self.audio_processor is not None:
542
- # Gemma3AudioFeatureExtractor produces "audio_values" and "audio_attention_mask"
543
- # Other keys like "audio_token_calc_sizes" are for internal use by processor, not model input
544
- input_names.update(["audio_values", "audio_attention_mask"])
545
  return list(input_names)
 
1
  import re
2
+ from typing import List, Optional, Union, Dict, Any
3
 
4
  import math
5
  import numpy as np
6
+ import scipy.signal
7
  import torch
8
  from torch.nn.utils.rnn import pad_sequence
9
 
 
15
 
16
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
17
  from transformers.feature_extraction_utils import BatchFeature
18
+ from transformers.processing_utils import ProcessorMixin, ProcessingKwargs
 
 
19
  from transformers.utils import TensorType, to_py_obj, logging
20
+ # For AutoImageProcessor, AutoTokenizer if needed for default loading
21
+ from transformers import AutoImageProcessor, AutoTokenizer
22
 
23
+
24
+ # Constants (as defined before)
25
  DEFAULT_SAMPLING_RATE = 16000
26
  DEFAULT_N_FFT = 512
27
+ DEFAULT_WIN_LENGTH = 400 # Will be n_fft if None in __init__
28
+ DEFAULT_HOP_LENGTH = 160 # Will be win_length // 4 if None in __init__
29
  DEFAULT_N_MELS = 80
30
+ DEFAULT_COMPRESSION_RATE = 4
31
+ DEFAULT_QFORMER_RATE = 2
32
+ DEFAULT_FEAT_STRIDE = 4
33
+ IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
34
+ AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
35
+ DEFAULT_MAX_LENGTH = 16384
36
  LOG_MEL_CLIP_EPSILON = 1e-5
37
 
38
  logger = logging.get_logger(__name__)
39
 
40
 
41
+ # create_mel_filterbank function (assuming it's correctly defined from previous response)
42
+ # ... (create_mel_filterbank function from the previous corrected response) ...
43
  def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
44
  fmax: Optional[float] = None) -> np.ndarray:
45
  """Create Mel filterbank for audio processing."""
 
48
  if fmin >= fmax:
49
  raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
50
 
51
+ def hz_to_mel(f: float) -> float: # Using HTK formula (as in librosa default)
52
  return 2595.0 * math.log10(1 + f / 700.0)
53
 
54
  def mel_to_hz(mel: float) -> float:
55
+ return 700.0 * (10**(mel / 2595.0) - 1)
56
 
57
  mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
58
  freq_points = mel_to_hz(mel_points)
59
+
 
60
  freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
61
+ bins = np.floor((n_fft / 2.0) * freq_points / (sampling_rate / 2.0)).astype(int)
62
+ bins = np.clip(bins, 0, n_fft // 2)
 
 
 
63
 
64
  filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
65
+ for m in range(n_mels):
66
  left, center, right = bins[m], bins[m + 1], bins[m + 2]
67
+
68
+ # Simplified triangle creation logic (more robust versions exist in libraries like librosa)
69
  if center > left:
70
+ filterbank[m, left:center+1] = (np.arange(left, center + 1) - left) / (center - left)
 
71
  if right > center:
72
+ filterbank[m, center:right+1] = (right - np.arange(center, right + 1)) / (right - center)
73
+ # If the slopes leave the center bin below 1 (e.g. when left == center or center == right),
74
+ # force the triangle peak to 1 so every non-degenerate filter has unit height.
75
+ if left <= center <= right and filterbank[m, center] < 1.0 and (center > left or center < right):
76
+ filterbank[m, center] = 1.0
 
 
 
 
 
 
 
 
 
82
 
 
83
 
84
+ return filterbank
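A quick way to sanity-check the filterbank returned above; a minimal sketch, assuming this module is importable as processing_gemma3_omni:

from processing_gemma3_omni import create_mel_filterbank

# Default geometry: 80 triangular filters spread over the 257 rFFT bins of a 512-point FFT.
fb = create_mel_filterbank(sampling_rate=16000, n_fft=512, n_mels=80)
assert fb.shape == (80, 512 // 2 + 1)
assert fb.min() >= 0.0 and fb.max() <= 1.0  # filters are non-negative and peak at 1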
85
 
86
+ # Gemma3AudioFeatureExtractor class (assuming it's correctly defined from previous response)
87
+ # ... (Gemma3AudioFeatureExtractor class from the previous corrected response) ...
88
  class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
89
+ model_input_names = ["audio_values", "audio_attention_mask"]
90
 
91
  def __init__(
92
  self,
93
  compression_rate: int = DEFAULT_COMPRESSION_RATE,
94
  qformer_rate: int = DEFAULT_QFORMER_RATE,
95
  feat_stride: int = DEFAULT_FEAT_STRIDE,
96
+ sampling_rate: int = DEFAULT_SAMPLING_RATE,
97
  n_fft: int = DEFAULT_N_FFT,
98
+ win_length: Optional[int] = None,
99
+ hop_length: Optional[int] = None,
100
  n_mels: int = DEFAULT_N_MELS,
101
  f_min: float = 0.0,
102
  f_max: Optional[float] = None,
103
+ padding_value: float = 0.0,
104
  **kwargs
105
  ):
 
106
  super().__init__(feature_size=n_mels, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
107
 
108
  self.compression_rate = compression_rate
109
  self.qformer_rate = qformer_rate
110
+ self.feat_stride = feat_stride
 
111
  self.n_fft = n_fft
112
  self.win_length = win_length if win_length is not None else n_fft
113
  self.hop_length = hop_length if hop_length is not None else self.win_length // 4
 
120
  f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
121
  f"For FFT computation, the window will effectively be truncated or the signal zero-padded to n_fft length."
122
  )
123
+ self.window = scipy.signal.get_window("hann", self.win_length).astype(np.float32)
124
  self.mel_filterbank = create_mel_filterbank(
125
  self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
126
+ ).T
127
 
128
  def __call__(
129
  self,
130
  audios: Union[AudioInput, List[AudioInput]],
131
+ sampling_rate: Optional[int] = None,
132
  return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
133
  ) -> BatchFeature:
134
+
135
  if not isinstance(audios, list):
136
  audios = [audios]
137
 
138
  processed_mel_spectrograms: List[torch.Tensor] = []
139
  actual_mel_lengths: List[int] = []
140
+ downstream_sizes_for_token_calc: List[torch.Tensor] = []
 
 
141
  downstream_frames_scaled_for_token_calc: List[int] = []
142
 
143
  for audio_input_item in audios:
 
154
  "sampling_rate must be provided if audio inputs are raw numpy arrays or lists."
155
  )
156
  source_sr = sampling_rate
157
+ else:
158
  raise TypeError(
159
  f"Unsupported audio input type: {type(audio_input_item)}. "
160
  "This extractor expects np.ndarray, list of floats, or Tuple[np.ndarray, int indicating SR]."
161
  )
162
 
163
  processed_wav = self._preprocess_audio(current_wav_array, source_sr)
164
+ mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
165
+
166
+ feature_tensor = torch.from_numpy(mel_spectrogram)
 
167
  processed_mel_spectrograms.append(feature_tensor)
168
+ actual_mel_lengths.append(feature_tensor.shape[0])
169
 
 
170
  downstream_sizes_for_token_calc.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
171
  downstream_frames_scaled_for_token_calc.append(feature_tensor.shape[0] * self.feat_stride)
172
 
 
 
173
  audio_values = pad_sequence(processed_mel_spectrograms, batch_first=True, padding_value=self.padding_value)
174
+ max_mel_len = audio_values.shape[1]
 
 
175
  lengths_tensor = torch.tensor(actual_mel_lengths, dtype=torch.long)
176
+ audio_attention_mask = torch.arange(max_mel_len).unsqueeze(0).expand(len(audios), -1) < lengths_tensor.unsqueeze(1)
177
+
 
 
178
  output_data = {
179
+ "audio_values": audio_values,
180
+ "audio_attention_mask": audio_attention_mask
181
  }
182
+
183
+ if downstream_sizes_for_token_calc:
184
+ output_data["audio_token_calc_sizes"] = torch.stack(downstream_sizes_for_token_calc)
185
+
 
 
 
186
  return BatchFeature(data=output_data, tensor_type=return_tensors)
187
 
188
  def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
189
  if wav.dtype not in [np.float32, np.float64]:
 
190
  if np.issubdtype(wav.dtype, np.integer):
191
  max_val = np.iinfo(wav.dtype).max
192
  wav = wav.astype(np.float32) / max_val
193
+ else:
194
  wav = wav.astype(np.float32)
195
+
196
  if wav.ndim > 1:
197
+ wav = wav.mean(axis=0)
198
+
199
  if source_sr != self.sampling_rate:
 
 
200
  gcd = math.gcd(self.sampling_rate, source_sr)
201
  up_factor = self.sampling_rate // gcd
202
  down_factor = source_sr // gcd
203
+ if up_factor != down_factor:
204
+ logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")
205
+ wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
206
+
 
207
  norm_factor = np.abs(wav).max()
208
+ if norm_factor > 1e-9:
209
  wav = wav / norm_factor
210
  return wav
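For reference, the integer up/down factors computed above resolve as follows for a common case; a small worked sketch (the 44.1 kHz source rate is an assumed example):

import math
target_sr, source_sr = 16000, 44100
gcd = math.gcd(target_sr, source_sr)           # 100
up, down = target_sr // gcd, source_sr // gcd  # 160, 441
# scipy.signal.resample_poly(wav, up=160, down=441) then yields 16 kHz audio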
211
 
 
214
  padding = self.win_length - len(wav)
215
  wav = np.pad(wav, (0, padding), mode='constant', constant_values=0.0)
216
 
 
 
217
  num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
218
  if num_frames <= 0:
219
+ logger.warning(f"Audio of length {len(wav)} is too short to produce frames with win_length {self.win_length} and hop_length {self.hop_length}. Returning empty mel spectrogram.")
 
220
  return np.zeros((0, self.n_mels), dtype=np.float32)
221
 
222
  frames = np.lib.stride_tricks.as_strided(
 
225
  strides=(wav.strides[0] * self.hop_length, wav.strides[0]),
226
  writeable=False
227
  )
228
+
229
  windowed_frames = frames * self.window
230
+ stft_matrix = np.fft.rfft(windowed_frames, n=self.n_fft, axis=-1)
231
+ powers = np.abs(stft_matrix)**2
232
+ mel_spectrogram = np.dot(powers, self.mel_filterbank)
 
 
 
233
  mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None)
234
  log_mel_spectrogram = np.log(mel_spectrogram)
235
+
236
  return log_mel_spectrogram.astype(np.float32)
237
 
238
  def _calculate_embed_length(self, frame_count: int) -> int:
 
 
239
  compressed = math.ceil(frame_count / self.compression_rate)
240
  return math.ceil(compressed / self.qformer_rate)
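With the extractor defaults (compression_rate=4, qformer_rate=2) this amounts to an 8x reduction of the mel frame count; a hedged numeric example with an assumed frame count:

import math
frame_count = 1000                       # assumed number of mel frames, for illustration
compressed = math.ceil(frame_count / 4)  # 250
embed_len = math.ceil(compressed / 2)    # 125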
241
 
242
+ class Gemma3DummyProcessorKwargs(ProcessingKwargs, total=False): # Dummy for testing structure
 
 
 
243
  images_kwargs: Dict[str, Any]
244
  audio_kwargs: Dict[str, Any]
245
+ text_kwargs: Dict[str, Any]
246
  _defaults = {
247
  "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
248
  "images_kwargs": {},
249
  "audio_kwargs": {}
250
  }
251
 
 
252
  class Gemma3OmniProcessor(ProcessorMixin):
253
  attributes = ["image_processor", "audio_processor", "tokenizer"]
254
+ # Define class attributes for ProcessorMixin to find/use them
255
+ image_processor_class = "AutoImageProcessor" # Or the specific class string if not auto
256
+ audio_processor_class = Gemma3AudioFeatureExtractor # Correctly points to your custom class
257
+ tokenizer_class = "AutoTokenizer" # Or the specific class string
258
 
259
+ # valid_kwargs was in user's code, its role depends on ProcessorMixin internal usage
260
+ valid_kwargs = ["chat_template", "image_seq_length"]
261
 
262
  def __init__(
263
  self,
264
+ tokenizer,
265
  audio_processor: Optional[Union[Gemma3AudioFeatureExtractor, Dict]] = None,
266
+ image_processor = None,
267
  chat_template=None,
268
+ image_seq_length: int = 256,
269
+ audio_prompt_compression_rate: int = 8,
 
270
  audio_prompt_qformer_rate: int = 1,
271
  audio_prompt_feat_stride: int = 1,
272
+ audio_placeholder_token: str = "<|audio_placeholder|>",
273
+ audio_soft_token_str: str = "<audio_soft_token>",
274
  **kwargs
275
  ):
276
+ # Instantiate audio_processor if config dict is passed or if None (use defaults)
277
  if audio_processor is None:
278
+ logger.info("Initializing Gemma3AudioFeatureExtractor with default parameters for Gemma3OmniProcessor.")
279
  audio_processor = Gemma3AudioFeatureExtractor()
280
  elif isinstance(audio_processor, Dict):
281
  audio_processor = Gemma3AudioFeatureExtractor(**audio_processor)
282
+ elif not isinstance(audio_processor, Gemma3AudioFeatureExtractor): # Check type if instance is passed
283
+ raise TypeError(f"audio_processor must be an instance of Gemma3AudioFeatureExtractor or a config dict, got {type(audio_processor)}")
284
+
285
+ # Handle image_processor similarly if it can be None or a dict
286
+ if image_processor is None and self.image_processor_class:
287
+ # This is a basic way; from_pretrained usually handles complex loading
288
+ if isinstance(self.image_processor_class, str) and self.image_processor_class == "AutoImageProcessor":
289
+ logger.info(f"Attempting to load a default {self.image_processor_class}. This might require a default model name or fail.")
290
+ # image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32") # Example default
291
+ # else if self.image_processor_class is an actual class, instantiate it.
292
+ elif isinstance(image_processor, Dict):
293
+ # image_processor = AutoImageProcessor.from_config(config_class(**image_processor)) # Example
294
+ pass # Actual instantiation from dict would be more complex
295
+
296
+ # Ensure tokenizer is an instantiated object
297
+ if isinstance(tokenizer, str): # If tokenizer is a string (model name/path)
298
+ logger.info(f"Loading tokenizer from {tokenizer}")
299
+ # tokenizer = AutoTokenizer.from_pretrained(tokenizer) # This is how it's usually done
300
+ elif tokenizer is None:
301
+ raise ValueError("A tokenizer instance or identifier must be provided.")
302
 
 
 
 
 
 
303
 
304
  super().__init__(
305
  image_processor=image_processor,
306
  audio_processor=audio_processor,
307
  tokenizer=tokenizer,
308
  chat_template=chat_template,
309
+ **kwargs # Pass other kwargs to super
310
  )
311
+
312
  self.image_seq_length = image_seq_length
313
+ self.image_token_id = getattr(self.tokenizer, "image_token_id", self.tokenizer.unk_token_id if hasattr(self.tokenizer, "unk_token_id") else None)
314
+ self.boi_token = getattr(self.tokenizer, "boi_token", "<|image|>")
315
+ self.image_token = getattr(self.tokenizer, "image_token", "<|image|>")
316
+ self.eoi_token = getattr(self.tokenizer, "eoi_token", "")
 
 
317
 
318
  self.audio_placeholder_token = audio_placeholder_token
319
  self.audio_soft_token_str = audio_soft_token_str
320
+
 
321
  self.audio_soft_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_soft_token_str)
322
+ if self.audio_soft_token_id == self.tokenizer.unk_token_id: # Check if UNK
323
+ logger.warning(
324
+ f"The audio soft token string '{self.audio_soft_token_str}' maps to UNK token (ID: {self.audio_soft_token_id}). "
325
  "Ensure it is added to the tokenizer's vocabulary as a special token."
326
  )
 
 
 
 
327
 
328
  self.full_image_sequence_str = f"\n\n{self.boi_token}{''.join([self.image_token] * self.image_seq_length)}{self.eoi_token}\n\n"
329
 
 
330
  self.audio_prompt_compression_rate = audio_prompt_compression_rate
331
  self.audio_prompt_qformer_rate = audio_prompt_qformer_rate
332
  self.audio_prompt_feat_stride = audio_prompt_feat_stride
333
 
334
+
335
+ def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_passed_to_call):
 
336
  final_kwargs = {}
337
  # Initialize with _defaults from the Kwargs class
338
+ # Ensure KwargsClassWithDefaults has a _defaults attribute
339
+ _defaults = getattr(KwargsClassWithDefaults, "_defaults", {})
340
+ for modality_key, default_modality_kwargs in _defaults.items():
341
  final_kwargs[modality_key] = default_modality_kwargs.copy()
342
 
343
  # Override with tokenizer's init_kwargs if they exist for a given key
344
  for modality_key, modality_dict in final_kwargs.items():
345
+ for key in list(modality_dict.keys()):
346
  if key in tokenizer_init_kwargs:
347
  modality_dict[key] = tokenizer_init_kwargs[key]
348
+
349
+ # Override with kwargs passed directly to __call__
350
+ for modality_key_from_call, modality_dict_from_call in kwargs_passed_to_call.items():
351
+ if modality_key_from_call in final_kwargs and isinstance(modality_dict_from_call, dict):
352
+ final_kwargs[modality_key_from_call].update(modality_dict_from_call)
353
+ # If a new modality_kwargs (e.g., "video_kwargs") is passed, add it
354
+ elif modality_key_from_call not in final_kwargs and isinstance(modality_dict_from_call, dict):
355
+ final_kwargs[modality_key_from_call] = modality_dict_from_call.copy()
356
+
357
+ # Specific handling for text_kwargs
358
+ if "text_kwargs" not in final_kwargs:
359
+ final_kwargs["text_kwargs"] = {} # Ensure it exists
360
+ final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
361
+ final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
362
+
 
 
 
363
  return final_kwargs
364
 
365
  def _compute_audio_prompt_token_count(self, actual_mel_frames_count: int) -> int:
 
 
366
  scaled_frames = actual_mel_frames_count * self.audio_prompt_feat_stride
367
  compressed_once = math.ceil(scaled_frames / self.audio_prompt_compression_rate)
368
  compressed_twice = math.ceil(compressed_once / self.audio_prompt_qformer_rate)
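With the prompt-side defaults (feat_stride=1, compression_rate=8, qformer_rate=1) the arithmetic looks like this; the frame count is again an assumed example:

import math
mel_frames = 1000
scaled = mel_frames * 1                 # feat_stride
once = math.ceil(scaled / 8)            # 125
soft_token_count = math.ceil(once / 1)  # 125 copies of the audio soft token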
 
371
  def __call__(
372
  self,
373
  text: Union[str, List[str]] = None,
374
+ images: Optional[Any] = None,
375
  audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
376
+ sampling_rate: Optional[int] = None,
377
+ return_tensors: Optional[Union[str, TensorType]] = None,
378
+ **kwargs: Any
379
  ) -> BatchFeature:
380
+
381
  if text is None and images is None and audios is None:
382
  raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
383
 
384
+ # Determine final return_tensors strategy
385
+ # Priority: 1. Explicit return_tensors, 2. from text_kwargs in **kwargs, 3. Default (PT)
386
+ final_rt = return_tensors
387
+ merged_call_kwargs = self._merge_kwargs(
388
+ Gemma3DummyProcessorKwargs, # Using dummy for _defaults structure
 
 
 
 
 
 
 
 
 
 
389
  self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},
390
+ **kwargs
391
  )
392
+
393
+ if final_rt is None: # If not passed directly to __call__
394
+ final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
395
+ else: # If passed directly, remove from text_kwargs to avoid conflict
396
+ merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
397
+
398
 
 
399
  if text is None:
400
  num_samples = 0
401
  if images is not None:
402
+ _images_list = images if isinstance(images, list) and (not images or not isinstance(images[0], (int, float))) else [images]
403
+ num_samples = len(_images_list)
404
  elif audios is not None:
405
+ _audios_list = audios if isinstance(audios, list) else [audios]
406
+ num_samples = len(_audios_list)
407
+ text = [""] * num_samples if num_samples > 0 else [""]
408
+
409
  if isinstance(text, str):
410
  text = [text]
411
  if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
412
  raise ValueError("Input `text` must be a string or a list of strings.")
413
 
414
+ image_features_dict = {}
 
415
  if images is not None and self.image_processor is not None:
416
  logger.info("Processing images...")
417
+ # image_features_dict = self.image_processor(images, return_tensors=None, **merged_call_kwargs.get("images_kwargs", {}))
418
+ # Simplified: Actual image token replacement logic for `text` would go here.
419
+ # text = self._handle_image_text_replacement(text, images, image_features_dict)
420
+ pass
421
+
422
+
423
+ audio_features_dict = {}
 
424
  if audios is not None and self.audio_processor is not None:
425
  logger.info("Processing audio...")
426
+ audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
427
+ if sampling_rate:
428
+ audio_call_kwargs["sampling_rate"] = sampling_rate
429
+
430
+ # audio_processor.__call__ returns BatchFeature, we need its .data attribute
431
+ audio_features_batch_feature = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
432
+ audio_features_dict = audio_features_batch_feature.data # Get the dict
433
 
 
 
 
 
 
434
  new_text_with_audio = []
435
+ # audio_attention_mask shape is (B, Max_T_mel)
436
+ audio_sample_mel_lengths = to_py_obj(audio_features_dict["audio_attention_mask"].sum(axis=-1))
 
437
 
438
  for i, prompt in enumerate(text):
439
+ num_soft_tokens = self._compute_audio_prompt_token_count(audio_sample_mel_lengths[i])
440
  audio_token_sequence_str = self.audio_soft_token_str * num_soft_tokens
441
+
442
  if self.audio_placeholder_token in prompt:
443
  prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str, 1)
444
+ else:
445
+ prompt += audio_token_sequence_str
446
  new_text_with_audio.append(prompt)
447
  text = new_text_with_audio
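Concretely, the loop above expands one placeholder occurrence into the computed number of soft tokens; a small illustration assuming a count of 3:

prompt = "Describe the sound: <|audio_placeholder|>"
expanded = prompt.replace("<|audio_placeholder|>", "<audio_soft_token>" * 3, 1)
# -> "Describe the sound: <audio_soft_token><audio_soft_token><audio_soft_token>"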
448
+
 
449
  logger.info("Tokenizing text...")
450
+ text_call_kwargs = merged_call_kwargs.get("text_kwargs", {})
451
+ text_features_dict = self.tokenizer(text, return_tensors=None, **text_call_kwargs)
 
452
 
453
+ input_ids_list = text_features_dict["input_ids"]
454
+ if not isinstance(input_ids_list, list) or not (input_ids_list and isinstance(input_ids_list[0], list)):
455
+ if isinstance(input_ids_list, (torch.Tensor, np.ndarray)):
456
+ input_ids_list = to_py_obj(input_ids_list) # Convert tensor/np.array to list of lists
457
+ elif isinstance(input_ids_list, list) and (not input_ids_list or isinstance(input_ids_list[0], int)):
458
+ input_ids_list = [input_ids_list]
459
 
460
  token_type_ids_list = []
461
  for ids_sample in input_ids_list:
462
+ types = [0] * len(ids_sample)
463
  for j, token_id in enumerate(ids_sample):
464
  if self.image_token_id is not None and token_id == self.image_token_id:
465
+ types[j] = 1
466
+ elif token_id == self.audio_soft_token_id:
467
+ types[j] = 2
468
  token_type_ids_list.append(types)
469
+ text_features_dict["token_type_ids"] = token_type_ids_list
470
+
471
+ combined_features = {**text_features_dict}
472
+ if image_features_dict:
473
+ combined_features.update(image_features_dict)
474
+ if audio_features_dict:
475
+ combined_features.update(audio_features_dict)
476
+
 
477
  return BatchFeature(data=combined_features, tensor_type=final_rt)
478
 
479
  def batch_decode(self, *args, **kwargs):
 
484
 
485
  @property
486
  def model_input_names(self) -> List[str]:
 
 
 
487
  input_names = set(self.tokenizer.model_input_names + ["token_type_ids"])
488
  if self.image_processor is not None:
489
  input_names.update(self.image_processor.model_input_names)
490
  if self.audio_processor is not None:
491
+ # From Gemma3AudioFeatureExtractor's output_data keys
492
+ input_names.update(["audio_values", "audio_attention_mask"])
493
+ # "audio_token_calc_sizes" is internal to processor, not model.
494
  return list(input_names)
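End-to-end, the updated feature extractor can be exercised on its own; a minimal usage sketch, assuming this module is importable as processing_gemma3_omni (the random waveform is only a placeholder input):

import numpy as np
from processing_gemma3_omni import Gemma3AudioFeatureExtractor

extractor = Gemma3AudioFeatureExtractor()        # 16 kHz target, 80 mel bins by default
wav = np.random.randn(32000).astype(np.float32)  # 2 s of noise, already at 16 kHz
features = extractor(audios=wav, sampling_rate=16000)

print(features["audio_values"].shape)            # torch.Size([1, num_mel_frames, 80])
print(features["audio_attention_mask"].shape)    # torch.Size([1, num_mel_frames])

Gemma3OmniProcessor then combines this extractor with a tokenizer and replaces <|audio_placeholder|> in the prompt with the matching number of <audio_soft_token> entries.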