voidful committed on
Commit 5fc5a97 · verified · 1 Parent(s): 32e0cd2

Update processing_gemma3_omni.py

Files changed (1)
  1. processing_gemma3_omni.py  +187 -439
processing_gemma3_omni.py CHANGED
@@ -1,16 +1,16 @@
  import re
- from typing import List, Optional, Union, Dict, Any, Tuple
+ from typing import List, Optional, Union, Dict, Any
 
  import math
  import numpy as np
  import scipy.signal
  import torch
  from torch.nn.utils.rnn import pad_sequence
- from transformers.audio_utils import AudioInput  # type: ignore
+ from transformers.audio_utils import AudioInput
  from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
  from transformers.feature_extraction_utils import BatchFeature
  from transformers.image_utils import make_nested_list_of_images
- from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs
+ from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs, Unpack
  from transformers.utils import TensorType, to_py_obj, logging
 
  # Constants
@@ -25,7 +25,6 @@ DEFAULT_FEAT_STRIDE = 4
  IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
  AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
  DEFAULT_MAX_LENGTH = 16384
- LOG_MEL_CLIP_EPSILON = 1e-5
 
  logger = logging.get_logger(__name__)
 
@@ -33,41 +32,25 @@ logger = logging.get_logger(__name__)
  def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                            fmax: Optional[float] = None) -> np.ndarray:
      """Create Mel filterbank for audio processing."""
-     fmax = fmax or sampling_rate / 2.0
+     fmax = fmax or sampling_rate / 2
 
      def hz_to_mel(f: float) -> float:
          return 1127.0 * math.log(1 + f / 700.0)
 
-     if fmin >= fmax:
-         raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
-
      mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
      freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
-     freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
      bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
-     bins = np.clip(bins, 0, n_fft // 2)
 
      filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
-     for m_idx in range(n_mels):
-         left, center, right = bins[m_idx], bins[m_idx + 1], bins[m_idx + 2]
-
-         # Robust triangular filter creation from your version
-         # (small adjustment to ensure slopes are only added if points are distinct)
-         if center > left:
-             filterbank[m_idx, left:center] = (np.arange(left, center) - left) / (center - left)
-         if right > center:
-             filterbank[m_idx, center:right] = (right - np.arange(center, right)) / (right - center)
-         # Ensure peak is 1.0 if center is a valid point, particularly if left=center or center=right
-         # This covers the case where a slope might not set the peak to 1 due to integer arithmetic.
-         if left <= center <= right and ((center > left and center <= right) or (center < right and center >= left)):
-             filterbank[m_idx, center] = 1.0
+     for m in range(1, n_mels + 1):
+         left, center, right = bins[m - 1:m + 2]
+         filterbank[m - 1, left:center] = (np.arange(left, center) - left) / (center - left)
+         filterbank[m - 1, center:right] = (right - np.arange(center, right)) / (right - center)
 
      return filterbank
 
 
  class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
-     model_input_names = ["audio_values", "audio_attention_mask"]
-
      def __init__(
          self,
          compression_rate: int = DEFAULT_COMPRESSION_RATE,
@@ -75,182 +58,89 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
          feat_stride: int = DEFAULT_FEAT_STRIDE,
          sampling_rate: int = DEFAULT_SAMPLING_RATE,
          n_fft: int = DEFAULT_N_FFT,
-         win_length: Optional[int] = None,
-         hop_length: Optional[int] = None,
+         win_length: int = DEFAULT_WIN_LENGTH,
+         hop_length: int = DEFAULT_HOP_LENGTH,
          n_mels: int = DEFAULT_N_MELS,
-         f_min: float = 0.0,
-         f_max: Optional[float] = None,
-         padding_value: float = 0.0,
          **kwargs
      ):
-         # Pop these before super().__init__ as they might conflict if also in kwargs
-         # and super() doesn't expect them, or if super() expects them but under different names.
-         # However, feature_size, sampling_rate, padding_value ARE arguments for SequenceFeatureExtractor.
-         # So, ensure they are passed correctly.
-         _feature_size = n_mels
-         _sampling_rate = sampling_rate
-         _padding_value = padding_value
-
-         # Remove them from kwargs if they were also passed via kwargs to avoid duplicate argument error
          kwargs.pop("feature_size", None)
          kwargs.pop("sampling_rate", None)
          kwargs.pop("padding_value", None)
 
-         _win_length = win_length if win_length is not None else n_fft
-         _hop_length = hop_length if hop_length is not None else _win_length // 4
-
          super().__init__(
-             feature_size=_feature_size,
-             sampling_rate=_sampling_rate,
-             padding_value=_padding_value,
+             feature_size=n_mels,
+             sampling_rate=sampling_rate,
+             padding_value=0.0,
              **kwargs
          )
 
          self.compression_rate = compression_rate
         self.qformer_rate = qformer_rate
          self.feat_stride = feat_stride
-         # self.sampling_rate is set by super()
+         self.sampling_rate = sampling_rate
 
+         self.window = np.hamming(win_length).astype(np.float32)
+         self.mel_filterbank = create_mel_filterbank(sampling_rate, n_fft, n_mels).T
          self.n_fft = n_fft
-         self.win_length = _win_length
-         self.hop_length = _hop_length
-         self.n_mels = n_mels
-         self.f_min = f_min
-         self.f_max = f_max
-
-         if self.win_length > self.n_fft:
-             logger.warning(
-                 f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
-                 "Window will be applied, then data will be zero-padded/truncated to n_fft by np.fft.rfft."
-             )
-         self.window = np.hamming(self.win_length).astype(np.float32)
-         self.mel_filterbank = create_mel_filterbank(
-             self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
-         ).T
+         self.hop_length = hop_length
+         self.win_length = win_length
 
      def __call__(
          self,
-         audios: Union[AudioInput, List[AudioInput]],
-         sampling_rate: Optional[int] = None,
+         audios: List[AudioInput],
          return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
      ) -> BatchFeature:
+         features, sizes, frames = [], [], []
 
-         if not isinstance(audios, list):
-             audios = [audios]
-
-         processed_mels: List[torch.Tensor] = []
-         actual_mel_lengths: List[int] = []
-         sizes_for_embed_length: List[torch.Tensor] = []
-         frames_scaled_by_feat_stride: List[int] = []
-
-         for audio_item in audios:
-             current_wav: np.ndarray
-             source_sr: int
-
-             if isinstance(audio_item, tuple) and len(audio_item) == 2 and isinstance(audio_item[1], int):
-                 current_wav, source_sr = audio_item
-                 current_wav = np.asarray(current_wav, dtype=np.float32)
-             elif isinstance(audio_item, (np.ndarray, list)):
-                 current_wav = np.asarray(audio_item, dtype=np.float32)
-                 if sampling_rate is None:
-                     raise ValueError(
-                         "sampling_rate must be provided if audio inputs are raw numpy arrays or lists without sr."
-                     )
-                 source_sr = sampling_rate
-             else:
-                 raise TypeError(
-                     f"Unsupported audio input type: {type(audio_item)}. "
-                     "Expected np.ndarray, list of floats, or Tuple[np.ndarray, int]."
-                 )
-
-             processed_wav_array = self._preprocess_audio(current_wav, source_sr)
-             mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav_array)
-
-             feature_tensor = torch.from_numpy(mel_spectrogram)
-             processed_mels.append(feature_tensor)
-             actual_mel_lengths.append(feature_tensor.shape[0])
-
-             sizes_for_embed_length.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
-             frames_scaled_by_feat_stride.append(feature_tensor.shape[0] * self.feat_stride)
-
-         audio_embeds = pad_sequence(processed_mels, batch_first=True, padding_value=self.padding_value)
-
-         max_t_mel_in_batch = audio_embeds.shape[1]
-
-         attention_mask = torch.zeros(len(audios), max_t_mel_in_batch,
-                                      dtype=torch.bool)  # Device handled by BatchFeature
-         for i, length in enumerate(actual_mel_lengths):
-             attention_mask[i, :length] = True
+         for wav in audios:
+             processed_wav = self._preprocess_audio(wav, 22500)
+             mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
+             feature_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32)
+             features.append(feature_tensor)
+             sizes.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
+             frames.append(feature_tensor.shape[0] * self.feat_stride)
+
+         audio_embeds = pad_sequence(features, batch_first=True)
+         size_tensor = torch.stack(sizes)
+
+         attention_mask = None
+         if len(audios) > 1:
+             frame_lengths = torch.tensor(frames)
+             attention_mask = torch.arange(frame_lengths.max()).unsqueeze(0) < frame_lengths.unsqueeze(1)
 
          output_data = {
              "audio_values": audio_embeds,
-             "audio_attention_mask": attention_mask
+             "audio_values_sizes": size_tensor
          }
-
-         if sizes_for_embed_length:
-             output_data["audio_values_sizes"] = torch.stack(sizes_for_embed_length)
-
-         logger.debug(
-             f"Gemma3AudioFeatureExtractor: Output 'audio_values' shape: {output_data['audio_values'].shape}")  # Verify output
+         if attention_mask is not None:
+             output_data["audio_attention_mask"] = attention_mask
 
          return BatchFeature(data=output_data, tensor_type=return_tensors)
 
      def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
-         if wav.dtype not in [np.float32, np.float64]:
-             if np.issubdtype(wav.dtype, np.integer):
-                 max_val = np.iinfo(wav.dtype).max if wav.size > 0 else 1.0
-                 wav = wav.astype(np.float32) / max_val
-             else:
-                 wav = wav.astype(np.float32)
-         elif wav.dtype == np.float64:
-             wav = wav.astype(np.float32)
-
+         wav = torch.as_tensor(wav).float().numpy()
          if wav.ndim > 1:
              wav = wav.mean(axis=0)
-
          if source_sr != self.sampling_rate:
-             # logger.info(f"Resampling audio from {source_sr} Hz to {self.sampling_rate} Hz.")  # logger might not be defined if this class is used standalone
-             common_divisor = math.gcd(self.sampling_rate, source_sr)
-             up_factor = self.sampling_rate // common_divisor
-             down_factor = source_sr // common_divisor
-             if up_factor != down_factor:
-                 wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
-
-         max_abs_val = np.abs(wav).max()
-         if max_abs_val > 1e-7:
-             wav = wav / max_abs_val
-         return wav
+             wav = scipy.signal.resample_poly(wav, self.sampling_rate, source_sr)
+         return wav / max(np.abs(wav).max(), 1e-6)
 
      def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
-         if len(wav) < self.win_length:
-             padding = self.win_length - len(wav)
-             wav = np.pad(wav, (0, padding), mode='constant', constant_values=0.0)
-
-         if len(wav) >= self.win_length:
-             num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
-         else:
-             num_frames = 0
-
-         if num_frames <= 0:
-             # logger.warning(...)  # logger might not be defined
-             return np.zeros((0, self.n_mels), dtype=np.float32)
-
-         frames_view = np.lib.stride_tricks.as_strided(
+         frame_count = 1 + (len(wav) - self.win_length) // self.hop_length
+         strides = wav.strides[0]
+         frames = np.lib.stride_tricks.as_strided(
              wav,
-             shape=(num_frames, self.win_length),
-             strides=(wav.strides[0] * self.hop_length, wav.strides[0]),
+             shape=(frame_count, self.win_length),
+             strides=(strides * self.hop_length, strides),
              writeable=False
-         )
-         frames_data = frames_view.copy()
-         frames_data *= self.window
+         ).copy()
+         frames *= self.window
 
-         spectrum = np.fft.rfft(frames_data, n=self.n_fft, axis=-1).astype(np.complex64)
+         spectrum = np.fft.rfft(frames, n=self.n_fft).astype(np.complex64)
          power = np.abs(spectrum) ** 2
          mel_spectrogram = np.dot(power, self.mel_filterbank)
-         mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None)
-         log_mel_spectrogram = np.log(mel_spectrogram)
-
-         return log_mel_spectrogram.astype(np.float32)
+         mel_spectrogram = np.clip(mel_spectrogram, 1.0, None)
+         return np.log(mel_spectrogram, dtype=np.float32)
 
      def _calculate_embed_length(self, frame_count: int) -> int:
          compressed = math.ceil(frame_count / self.compression_rate)
@@ -266,9 +156,8 @@ class Gemma3ImagesKwargs(ImagesKwargs):
 
 
  class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
-     images_kwargs: Optional[Dict[str, Any]] = None
-     audio_kwargs: Optional[Dict[str, Any]] = None
-     text_kwargs: Optional[Dict[str, Any]] = None
+     images_kwargs: Dict[str, Any]
+     audio_kwargs: Dict[str, Any]
      _defaults = {
          "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
          "images_kwargs": {},
@@ -279,23 +168,38 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
  class Gemma3OmniProcessor(ProcessorMixin):
      attributes = ["image_processor", "audio_processor", "tokenizer"]
      valid_kwargs = ["chat_template", "image_seq_length"]
-
-     # --- CRITICAL FIX: Use STRING names for auto-loading by ProcessorMixin ---
      image_processor_class = "AutoImageProcessor"
-     audio_processor_class = "AutoFeatureExtractor"  # Must match the class name string
+     audio_processor_class = "AutoFeatureExtractor"
      tokenizer_class = "AutoTokenizer"
 
      def __init__(
          self,
-         image_processor=None,
-         audio_processor=None,
-         tokenizer=None,
+         image_processor,
+         audio_processor,
+         tokenizer,
          chat_template=None,
          image_seq_length: int = 256,
-         **kwargs  # Catch-all for other potential superclass args or future additions
+         **kwargs
      ):
-         # ProcessorMixin.__init__ handles instantiation of image_processor, audio_processor, tokenizer
-         # if they are None, using the *_class attributes.
+         self.image_seq_length = image_seq_length
+         self.image_token_id = tokenizer.image_token_id
+         self.boi_token = tokenizer.boi_token
+         self.image_token = tokenizer.image_token
+         self.audio_token = "<audio_soft_token>"
+         self.expected_audio_token_id = 262143
+         self.full_image_sequence = f"\n\n{tokenizer.boi_token}{''.join([tokenizer.image_token] * image_seq_length)}{tokenizer.eoi_token}\n\n"
+
+         self.compression_rate = 8
+         self.qformer_compression_rate = 1
+         self.feat_stride = 1
+
+         self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
+         if self.audio_token_id != self.expected_audio_token_id:
+             logger.warning(
+                 f"Assigned ID {self.audio_token_id} for '{self.audio_token}' does not match expected ID {self.expected_audio_token_id}. "
+                 "Using assigned ID. Model embedding layer may need resizing."
+             )
+
          super().__init__(
              image_processor=image_processor,
              audio_processor=audio_processor,
@@ -304,281 +208,136 @@ class Gemma3OmniProcessor(ProcessorMixin):
              **kwargs
          )
 
-         # Attributes dependent on an instantiated tokenizer.
-         # self.tokenizer should be populated by super().__init__ by this point.
-         self.image_seq_length = image_seq_length
-         if self.tokenizer is not None:
-             self.image_token_id = getattr(self.tokenizer, "image_token_id",
-                                           self.tokenizer.unk_token_id if hasattr(self.tokenizer,
-                                                                                  "unk_token_id") else None)
-             self.boi_token = getattr(self.tokenizer, "boi_token", "<UNUSED_BOI>")
-             self.image_token = getattr(self.tokenizer, "image_token", "<UNUSED_IMG_TOKEN>")
-             self.eoi_token = getattr(self.tokenizer, "eoi_token", "<UNUSED_EOI>")
-
-             # User's original audio token attributes
-             self.audio_token_str_from_user_code = "<audio_soft_token>"  # From user's original code
-             # self.expected_audio_token_id = 262143  # User's reference
-
-             self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token_str_from_user_code)
-             # User's original warning logic for audio_token_id
-             # if self.audio_token_id != self.expected_audio_token_id:  # Comparing to a fixed ID
-             #     logger.warning(f"Assigned ID {self.audio_token_id} for '{self.audio_token_str_from_user_code}' does not match expected ID {self.expected_audio_token_id}.")
-             if hasattr(self.tokenizer, "unk_token_id") and self.audio_token_id == self.tokenizer.unk_token_id:
-                 logger.warning(
-                     f"Audio token '{self.audio_token_str_from_user_code}' not found in tokenizer, maps to UNK. Ensure it's added as a special token.")
-
-             self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * image_seq_length)}{self.eoi_token}\n\n"
-         else:
-             # This case should ideally not happen if from_pretrained works correctly.
-             logger.error(
-                 "Gemma3OmniProcessor initialized, but tokenizer is None. Token-dependent attributes will be missing or use placeholders.")
-             self.image_token_id = None
-             self.boi_token = "<UNUSED_BOI>"
-             self.image_token = "<UNUSED_IMG_TOKEN>"
-             self.eoi_token = "<UNUSED_EOI>"
-             self.audio_token_str_from_user_code = "<audio_soft_token>"
-             self.audio_token_id = -1
-             self.full_image_sequence = ""
-
-         # These are parameters for this processor's logic of determining audio token sequence length for prompts
-         # They were fixed values in user's original __init__
-         self.prompt_audio_compression_rate = 8
-         self.prompt_audio_qformer_compression_rate = 1
-         self.prompt_audio_feat_stride = 1
-
-     def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_from_call):
-         final_kwargs = {}
-         _defaults = getattr(KwargsClassWithDefaults, "_defaults", {})
-         if not isinstance(_defaults, dict): _defaults = {}
-
-         for modality_key, default_modality_kwargs in _defaults.items():
-             final_kwargs[modality_key] = default_modality_kwargs.copy()
-
-         for modality_key_in_call, modality_kwargs_in_call in kwargs_from_call.items():
-             if modality_key_in_call in final_kwargs:
-                 if isinstance(modality_kwargs_in_call, dict):
-                     final_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
-             elif isinstance(modality_kwargs_in_call, dict):
-                 final_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()
-
-         for modality_key in final_kwargs:
-             modality_dict = final_kwargs[modality_key]
-             if isinstance(modality_dict, dict) and self.tokenizer is not None:  # Check tokenizer exists
-                 for key_in_mod_dict in list(modality_dict.keys()):
-                     if key_in_mod_dict in tokenizer_init_kwargs:
-                         value = (
-                             getattr(self.tokenizer, key_in_mod_dict)
-                             if hasattr(self.tokenizer, key_in_mod_dict)
-                             else tokenizer_init_kwargs[key_in_mod_dict]
-                         )
-                         modality_dict[key_in_mod_dict] = value
+     def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs):
+         default_kwargs = {}
+         for modality in ModelProcessorKwargs._defaults:
+             default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
+
+         # Update defaults with tokenizer init kwargs
+         for modality in default_kwargs:
+             modality_kwargs = default_kwargs[modality]
+             for key in modality_kwargs:
+                 if key in tokenizer_init_kwargs:
+                     value = (
+                         getattr(self.tokenizer, key)
+                         if hasattr(self.tokenizer, key)
+                         else tokenizer_init_kwargs[key]
+                     )
+                     modality_kwargs[key] = value
 
-         if "text_kwargs" not in final_kwargs:
-             final_kwargs["text_kwargs"] = {}
-         final_kwargs["text_kwargs"]["truncation"] = final_kwargs["text_kwargs"].get("truncation", False)
-         final_kwargs["text_kwargs"]["max_length"] = final_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
+         # Update with user-provided kwargs
+         for modality in default_kwargs:
+             if modality in kwargs:
+                 default_kwargs[modality].update(kwargs[modality])
 
-         return final_kwargs
+         # Ensure text_kwargs has truncation=False and large max_length
+         default_kwargs["text_kwargs"]["truncation"] = False
+         default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get("max_length",
+                                                                                         DEFAULT_MAX_LENGTH)
 
-     def _compute_audio_embed_size(self, audio_mel_frames: int) -> int:
-         # Using processor's parameters for calculating number of special tokens in text prompt
-         scaled_frames = audio_mel_frames * self.prompt_audio_feat_stride
-         result = math.ceil(scaled_frames / self.prompt_audio_compression_rate)
-         return math.ceil(result / self.prompt_audio_qformer_rate)
+         return default_kwargs
+
+     def _compute_audio_embed_size(self, audio_frames: int) -> int:
+         result = math.ceil(audio_frames / self.compression_rate)
+         return math.ceil(result / self.qformer_compression_rate)
 
      def __call__(
          self,
-         text: Union[str, List[str]] = None,
-         images: Optional[Any] = None,
-         audios: Optional[Union[AudioInput, List[AudioInput]]] = None,
-         sampling_rate: Optional[int] = None,
-         return_tensors: Optional[Union[str, TensorType]] = None,
-         **kwargs: Any
+         images=None,
+         text=None,
+         videos=None,
+         audio=None,
+         **kwargs: Unpack[Gemma3ProcessorKwargs]
      ) -> BatchFeature:
-         if text is None and images is None and audios is None:
-             raise ValueError("Provide at least one of `text`, `images`, or `audios`.")
+         if text is None and images is None:
+             raise ValueError("Provide at least one of `text` or `images`.")
 
-         final_rt = return_tensors
-         merged_call_kwargs = self._merge_kwargs(
-             Gemma3ProcessorKwargs,  # Use the defined Kwargs class
-             self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},
+         output_kwargs = self._merge_kwargs(
+             Gemma3ProcessorKwargs,
+             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
              **kwargs
          )
 
-         if final_rt is None:
-             final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
-         else:
-             merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
-
-         if text is None:
-             num_samples = 0
-             if images is not None:
-                 _images_list = images if isinstance(images, list) and (
-                     not images or not isinstance(images[0], (int, float))) else [images]
-                 num_samples = len(_images_list)
-             elif audios is not None:
-                 _audios_list = audios if isinstance(audios, list) else [audios]
-                 num_samples = len(_audios_list)
-             text = [""] * num_samples if num_samples > 0 else [""]
-
          if isinstance(text, str):
              text = [text]
-         if not (isinstance(text, list) and all(isinstance(t, str) for t in text)):
-             raise ValueError("Input `text` must be a string or a list of strings.")
+         elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
+             raise ValueError("Input text must be a string or list of strings")
 
-         image_features_dict = {}
-         # --- Image Processing (User's original structure, with safety for image_processor) ---
+         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
+         image_inputs = {}
          if images is not None:
-             if self.image_processor is None:
-                 raise ValueError("Images were provided, but `self.image_processor` is not set.")
              batched_images = make_nested_list_of_images(images)
-             _img_proc_output = self.image_processor(batched_images, return_tensors=None,
-                                                     **merged_call_kwargs.get("images_kwargs", {}))
-             image_features_dict = _img_proc_output.data if isinstance(_img_proc_output,
-                                                                       BatchFeature) else _img_proc_output
-
-             if len(text) == 0 and len(batched_images) > 0:  # If text was initially None and images provided
-                 text = [" ".join([self.boi_token] * len(img_batch)) for img_batch in batched_images]
-             elif len(batched_images) != len(text):
-                 raise ValueError(f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts")
-
-             num_crops_popped = image_features_dict.pop("num_crops", None)
-             if num_crops_popped is not None:
-                 num_crops_all = to_py_obj(num_crops_popped)
-                 # ... (user's complex crop and text modification logic - kept as per original) ...
-                 # This part needs careful attention to ensure num_crops_all aligns with batched_images
-                 # For simplicity, the following is a conceptual placeholder of the user's original intent
-                 processed_text_for_images = []
-                 current_crop_idx_offset = 0
-                 for batch_idx, (prompt, current_imgs_in_batch) in enumerate(zip(text, batched_images)):
-                     crops_for_this_batch_sample = []
-                     if num_crops_all:  # Check if num_crops_all is not empty
-                         for _ in current_imgs_in_batch:
-                             if current_crop_idx_offset < len(num_crops_all):
-                                 crops_for_this_batch_sample.append(num_crops_all[current_crop_idx_offset])
-                                 current_crop_idx_offset += 1
-                             else:
-                                 crops_for_this_batch_sample.append(0)  # Should not happen
-
-                     image_indexes = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
-                     # ... (The rest of user's loop for image token replacement) ...
-                     # This was:
-                     # for num, idx in reversed(list(zip(crops_for_this_batch_sample, image_indexes))):
-                     #     if num > 0 : ...
-                     # text[batch_idx] = prompt
-                     # For minimal change, I'll assume this part is complex and specific.
-                     # A simplified version:
-                     prompt_with_full_seq = prompt.replace(self.boi_token, self.full_image_sequence,
-                                                           len(current_imgs_in_batch) if image_indexes else 0)
-                     processed_text_for_images.append(prompt_with_full_seq)
-                 text = processed_text_for_images
-             else:  # if no num_crops, simpler replacement
-                 text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
-
-         # --- Audio Processing ---
-         audio_features_dict = {}
-         if audios is not None:
-             if self.audio_processor is None:
-                 raise ValueError("Audios were provided, but `self.audio_processor` is not set.")
-
-             audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
-             if sampling_rate is not None:
-                 audio_call_kwargs["sampling_rate"] = sampling_rate
-
-             _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
-             audio_features_dict = _audio_proc_output.data
-             logger.debug(
-                 f"Gemma3OmniProcessor: Shape of 'audio_values' from Feature Extractor: {audio_features_dict['audio_values'].shape}")
-
-             new_text_with_audio = []
-             actual_mel_frames_per_sample = to_py_obj(audio_features_dict["audio_attention_mask"].sum(axis=-1))
-
-             if len(actual_mel_frames_per_sample) != len(text):  # Check batch consistency
+             image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
+
+             if not text:
+                 text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
+
+             if len(batched_images) != len(text):
                  raise ValueError(
-                     f"Inconsistent batch sizes for audio and text: {len(actual_mel_frames_per_sample)} audio samples, {len(text)} texts.")
-
-             for i, prompt in enumerate(text):
-                 num_soft_tokens = self._compute_audio_embed_size(actual_mel_frames_per_sample[i])
-                 # User's original audio_tokens dictionary for constructing the sequence
-                 _audio_token_str = self.audio_token_str_from_user_code  # e.g. "<audio_soft_token>"
-                 _boa_token_str = getattr(self.tokenizer, "bos_token", " ")  # Using BOS or space as BOA
-                 _eoa_token_str = getattr(self.tokenizer, "eos_token", "<|endoftext|>")  # Using EOS as EOA
-
-                 audio_token_sequence_str = f"{_boa_token_str}{''.join([_audio_token_str] * num_soft_tokens)}{_eoa_token_str}"
-
-                 # User's replacement logic used boa_token as placeholder. This can be made more robust.
-                 # Using a dedicated placeholder is safer. For now, mimicking user's approach.
-                 # The user's code used `audio_tokens_map['boa_token']` (which was " ") as placeholder.
-                 placeholder_str = _boa_token_str
-                 if prompt.strip().startswith(
-                         placeholder_str.strip()) and placeholder_str.strip() != "":  # Avoid replacing all spaces
-                     prompt = prompt.replace(placeholder_str, audio_token_sequence_str, 1)  # Replace first
-                 elif self.audio_placeholder_token in prompt:  # Check for a more explicit placeholder
-                     prompt = prompt.replace(self.audio_placeholder_token, audio_token_sequence_str, 1)
-                 else:
-                     prompt += audio_token_sequence_str
-                 new_text_with_audio.append(prompt)
-             text = new_text_with_audio
-
-         # --- Text Tokenization ---
-         text_tokenizer_kwargs = merged_call_kwargs.get("text_kwargs", {})
-         text_features_dict = self.tokenizer(text=text, return_tensors=None, **text_tokenizer_kwargs)
-
-         # Debug log from user - ensure input_ids_list_of_lists is correctly formed
-         input_ids_list_of_lists = text_features_dict["input_ids"]
-         if not isinstance(input_ids_list_of_lists, list) or not (
-                 input_ids_list_of_lists and isinstance(input_ids_list_of_lists[0], list)):
-             if isinstance(input_ids_list_of_lists, (torch.Tensor, np.ndarray)):
-                 input_ids_list_of_lists = to_py_obj(input_ids_list_of_lists)
-             elif isinstance(input_ids_list_of_lists, list) and (
-                     not input_ids_list_of_lists or isinstance(input_ids_list_of_lists[0], int)):
-                 input_ids_list_of_lists = [input_ids_list_of_lists]
-
-         for i, (txt, ids) in enumerate(zip(text, input_ids_list_of_lists)):
-             if not isinstance(ids, list): ids = []
-             audio_text_count = txt.count(self.audio_token_str_from_user_code)
-             audio_ids_count = ids.count(self.audio_token_id)
+                     f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts"
+                 )
+
+             num_crops = to_py_obj(image_inputs.pop("num_crops"))
+             batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
+
+             for batch_idx, (prompt, images, crops) in enumerate(zip(text, batched_images, batch_num_crops)):
+                 image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
+                 if len(images) != len(image_indexes):
+                     raise ValueError(
+                         f"Prompt has {len(image_indexes)} image tokens but received {len(images)} images"
+                     )
+
+                 for num, idx in reversed(list(zip(crops, image_indexes))):
+                     if num:
+                         formatted_image_text = (
+                             f"Here is the original image {self.boi_token} and here are some crops to help you see better "
+                             + " ".join([self.boi_token] * num)
+                         )
+                         prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token):]
+                         text[batch_idx] = prompt
+
+             text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
+
+         audio_inputs = {}
+         if audio is not None:
+             audio_inputs = self.audio_processor(audio, return_tensors)
+             audio_embeds = audio_inputs['audio_values']
+             audio_frames = audio_embeds.shape[1] * self.feat_stride
+             audio_seq_length = self._compute_audio_embed_size(audio_frames)
+
+             audio_tokens = {
+                 "boa_token": "<start_of_audio>",
+                 "eoa_token": "<end_of_audio>",
+                 "audio_token": "<audio_soft_token>",
+                 "boa_token_id": 256001,
+                 "eoa_token_id": 256002,
+                 "audio_token_id": self.audio_token_id  # Use dynamic ID
+             }
+
+             audio_sequence = f"\n\n{audio_tokens['boa_token']}{''.join([audio_tokens['audio_token']] * audio_seq_length)}{audio_tokens['eoa_token']}\n\n"
+             text = [prompt.replace(audio_tokens['boa_token'], audio_sequence) for prompt in text]
+
+         text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors=return_tensors)
+
+         # Debug: Log text and token counts before validation
+         for i, (txt, ids) in enumerate(zip(text, text_inputs["input_ids"])):
+             audio_text_count = txt.count(self.audio_token)
+             audio_ids_count = list(ids).count(self.audio_token_id)
              logger.debug(
-                 f"Sample {i}: Audio tokens ('{self.audio_token_str_from_user_code}') in text count={audio_text_count}, "
-                 f"in input_ids (ID:{self.audio_token_id}) count={audio_ids_count}. "
-                 f"Text snippet='{txt[:100]}...', Input IDs length={len(ids)}"
+                 f"Sample {i}: Audio tokens in text={audio_text_count}, in input_ids={audio_ids_count}, "
+                 f"Text length={len(txt)}, Input IDs length={len(ids)}"
              )
 
-         # Token type IDs from user's code
-         # Convert to numpy for boolean indexing, then back to list.
-         # This assumes input_ids_list_of_lists is now correctly a list of lists of ints.
-         # To make it robust for padding, pad token_type_ids as well if input_ids are padded by tokenizer.
-         # For now, assuming tokenizer with return_tensors=None gives unpadded list of lists.
-         padded_input_ids_for_token_type, _ = self._pad_মাদের(input_ids_list_of_lists)  # Custom helper needed
-
-         mm_token_type_ids_np = np.zeros_like(padded_input_ids_for_token_type, dtype=int)
-         if self.image_token_id is not None:
-             mm_token_type_ids_np[padded_input_ids_for_token_type == self.image_token_id] = 1
-         if self.audio_token_id != -1:  # Check if audio_token_id is valid
-             mm_token_type_ids_np[padded_input_ids_for_token_type == self.audio_token_id] = 2
-         text_features_dict["token_type_ids"] = mm_token_type_ids_np.tolist()
-
-         # Ensure attention_mask from tokenizer is also included if padding was applied by tokenizer
-         # text_features_dict should already contain 'attention_mask' if padding=True for tokenizer
-
-         final_batch_data = {**text_features_dict}
-         if image_features_dict:
-             final_batch_data.update(image_features_dict)
-         if audio_features_dict:
-             final_batch_data.update(audio_features_dict)
-
-         return BatchFeature(data=final_batch_data, tensor_type=final_rt)
-
-     # Helper for padding list of lists, if tokenizer does not do it with return_tensors=None
-     def _pad_মাদের(self, list_of_lists: List[List[int]], padding_value: int = 0) -> Tuple[np.ndarray, np.ndarray]:
-         if not list_of_lists: return np.array([]), np.array([])
-         max_len = max(len(sublist) for sublist in list_of_lists)
-         padded_array = np.full((len(list_of_lists), max_len), padding_value, dtype=int)
-         attention_mask = np.zeros((len(list_of_lists), max_len), dtype=int)
-         for i, sublist in enumerate(list_of_lists):
-             padded_array[i, :len(sublist)] = sublist
-             attention_mask[i, :len(sublist)] = 1
-         return padded_array, attention_mask
+         array_ids = text_inputs["input_ids"]
+         if return_tensors == "pt":
+             mm_token_type_ids = torch.zeros_like(array_ids)
+         else:
+             mm_token_type_ids = np.zeros_like(array_ids)
+         mm_token_type_ids[array_ids == self.image_token_id] = 1  # Image token type
+         mm_token_type_ids[array_ids == self.audio_token_id] = 2  # Audio token type
+         text_inputs["token_type_ids"] = mm_token_type_ids
+
+         return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
 
      def batch_decode(self, *args, **kwargs):
          return self.tokenizer.batch_decode(*args, **kwargs)
@@ -588,18 +347,7 @@ class Gemma3OmniProcessor(ProcessorMixin):
 
      @property
      def model_input_names(self):
-         # User's original logic, slightly more robust with hasattr checks
-         tokenizer_inputs = []
-         if hasattr(self, 'tokenizer') and self.tokenizer is not None:
-             tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
-
-         image_processor_inputs = []
-         if hasattr(self, 'image_processor') and self.image_processor is not None:
-             image_processor_inputs = self.image_processor.model_input_names
-
-         audio_processor_inputs = []
-         if hasattr(self, 'audio_processor') and self.audio_processor is not None:
-             audio_processor_inputs = getattr(self.audio_processor, "model_input_names",
-                                              ["audio_values", "audio_attention_mask"])
-
-         return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs + audio_processor_inputs))
+         tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
+         image_processor_inputs = self.image_processor.model_input_names
+         audio_processor_inputs = self.audio_processor.model_input_names
+         return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs + audio_processor_inputs))
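For readers who want to sanity-check the new audio path in isolation, the sketch below mirrors the strided framing, Hamming window, and triangular Mel filterbank logic that the updated Gemma3AudioFeatureExtractor applies, using NumPy only. It is an illustrative sketch, not code from this repository: the sampling rate, FFT size, window, hop, and Mel-bin count used here (16 kHz, 512, 400, 160, 40), the helper name mel_filterbank, and the 440 Hz test tone are all assumed values chosen for the example.

import math
import numpy as np

# Assumed illustrative parameters (not taken from this repository's config).
SR, N_FFT, WIN, HOP, N_MELS = 16000, 512, 400, 160, 40

def mel_filterbank(sampling_rate, n_fft, n_mels, fmin=0.0, fmax=None):
    # Same triangular-filter construction as create_mel_filterbank in the diff above.
    fmax = fmax or sampling_rate / 2
    hz_to_mel = lambda f: 1127.0 * math.log(1 + f / 700.0)
    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
    fb = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
    for m in range(1, n_mels + 1):
        left, center, right = bins[m - 1:m + 2]
        fb[m - 1, left:center] = (np.arange(left, center) - left) / (center - left)
        fb[m - 1, center:right] = (right - np.arange(center, right)) / (right - center)
    return fb

# One second of a 440 Hz tone stands in for real speech input.
wav = np.sin(2 * np.pi * 440 * np.arange(SR) / SR).astype(np.float32)

# Strided framing plus Hamming window, as in _compute_log_mel_spectrogram.
frame_count = 1 + (len(wav) - WIN) // HOP
stride = wav.strides[0]
frames = np.lib.stride_tricks.as_strided(
    wav, shape=(frame_count, WIN), strides=(stride * HOP, stride), writeable=False
).copy()
frames *= np.hamming(WIN).astype(np.float32)

# Power spectrum -> Mel projection -> log, clipped at 1.0 as in the committed code.
power = np.abs(np.fft.rfft(frames, n=N_FFT)) ** 2
log_mel = np.log(np.clip(power @ mel_filterbank(SR, N_FFT, N_MELS).T, 1.0, None))
print(log_mel.shape)  # (98, 40) for these parameters

Because values are clipped at 1.0 before the log, as in the committed _compute_log_mel_spectrogram, near-silent frames map to 0.0 rather than to large negative values, which is worth keeping in mind when comparing this front end against other Mel implementations.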