Update processing_gemma3_omni.py
processing_gemma3_omni.py  (+187 -439)
CHANGED
--- processing_gemma3_omni.py (removed / old side)

@@ -1,16 +1,16 @@
 import re
-from typing import List, Optional, Union, Dict, Any
 
 import math
 import numpy as np
 import scipy.signal
 import torch
 from torch.nn.utils.rnn import pad_sequence
-from transformers.audio_utils import AudioInput
 from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
 from transformers.feature_extraction_utils import BatchFeature
 from transformers.image_utils import make_nested_list_of_images
-from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs
 from transformers.utils import TensorType, to_py_obj, logging
 
 # Constants
@@ -25,7 +25,6 @@ DEFAULT_FEAT_STRIDE = 4
 IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
 AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
 DEFAULT_MAX_LENGTH = 16384
-LOG_MEL_CLIP_EPSILON = 1e-5
 
 logger = logging.get_logger(__name__)
 
@@ -33,41 +32,25 @@ logger = logging.get_logger(__name__)
 def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                           fmax: Optional[float] = None) -> np.ndarray:
     """Create Mel filterbank for audio processing."""
-    fmax = fmax or sampling_rate / 2
 
     def hz_to_mel(f: float) -> float:
         return 1127.0 * math.log(1 + f / 700.0)
 
-    if fmin >= fmax:
-        raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")
-
     mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
     freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
-    freq_points = np.clip(freq_points, 0, sampling_rate / 2.0)
     bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
-    bins = np.clip(bins, 0, n_fft // 2)
 
     filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
-    for m_idx in range(n_mels):
-        left, center, right = bins[m_idx:m_idx + 3]
-
-
-        # (small adjustment to ensure slopes are only added if points are distinct)
-        if center > left:
-            filterbank[m_idx, left:center] = (np.arange(left, center) - left) / (center - left)
-        if right > center:
-            filterbank[m_idx, center:right] = (right - np.arange(center, right)) / (right - center)
-        # Ensure peak is 1.0 if center is a valid point, particularly if left=center or center=right
-        # This covers the case where a slope might not set the peak to 1 due to integer arithmetic.
-        if left <= center <= right and ((center > left and center <= right) or (center < right and center >= left)):
-            filterbank[m_idx, center] = 1.0
 
     return filterbank
 
 
 class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
-    model_input_names = ["audio_values", "audio_attention_mask"]
-
     def __init__(
         self,
         compression_rate: int = DEFAULT_COMPRESSION_RATE,
@@ -75,182 +58,89 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
         feat_stride: int = DEFAULT_FEAT_STRIDE,
         sampling_rate: int = DEFAULT_SAMPLING_RATE,
         n_fft: int = DEFAULT_N_FFT,
-        win_length: Optional[int] = None,
-        hop_length: Optional[int] = None,
         n_mels: int = DEFAULT_N_MELS,
-        f_min: float = 0.0,
-        f_max: Optional[float] = None,
-        padding_value: float = 0.0,
         **kwargs
     ):
-        # Pop these before super().__init__ as they might conflict if also in kwargs
-        # and super() doesn't expect them, or if super() expects them but under different names.
-        # However, feature_size, sampling_rate, padding_value ARE arguments for SequenceFeatureExtractor.
-        # So, ensure they are passed correctly.
-        _feature_size = n_mels
-        _sampling_rate = sampling_rate
-        _padding_value = padding_value
-
-        # Remove them from kwargs if they were also passed via kwargs to avoid duplicate argument error
         kwargs.pop("feature_size", None)
         kwargs.pop("sampling_rate", None)
         kwargs.pop("padding_value", None)
 
-        _win_length = win_length if win_length is not None else n_fft
-        _hop_length = hop_length if hop_length is not None else _win_length // 4
-
         super().__init__(
-            feature_size=_feature_size,
-            sampling_rate=_sampling_rate,
-            padding_value=_padding_value,
             **kwargs
         )
 
         self.compression_rate = compression_rate
         self.qformer_rate = qformer_rate
         self.feat_stride = feat_stride
-
 
         self.n_fft = n_fft
-        self.win_length = _win_length
-        self.hop_length = _hop_length
-        self.n_mels = n_mels
-        self.f_min = f_min
-        self.f_max = f_max
-
-        if self.win_length > self.n_fft:
-            logger.warning(
-                f"win_length ({self.win_length}) is greater than n_fft ({self.n_fft}). "
-                "Window will be applied, then data will be zero-padded/truncated to n_fft by np.fft.rfft."
-            )
-        self.window = np.hamming(self.win_length).astype(np.float32)
-        self.mel_filterbank = create_mel_filterbank(
-            self.sampling_rate, self.n_fft, self.n_mels, fmin=self.f_min, fmax=self.f_max
-        ).T
 
     def __call__(
         self,
-        audios:
-        sampling_rate: Optional[int] = None,
         return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
     ) -> BatchFeature:
 
-        [removed lines not captured in this view]
-        for audio_item in audios:
-            current_wav: np.ndarray
-            source_sr: int
-
-            if isinstance(audio_item, tuple) and len(audio_item) == 2 and isinstance(audio_item[1], int):
-                current_wav, source_sr = audio_item
-                current_wav = np.asarray(current_wav, dtype=np.float32)
-            elif isinstance(audio_item, (np.ndarray, list)):
-                current_wav = np.asarray(audio_item, dtype=np.float32)
-                if sampling_rate is None:
-                    raise ValueError(
-                        "sampling_rate must be provided if audio inputs are raw numpy arrays or lists without sr."
-                    )
-                source_sr = sampling_rate
-            else:
-                raise TypeError(
-                    f"Unsupported audio input type: {type(audio_item)}. "
-                    "Expected np.ndarray, list of floats, or Tuple[np.ndarray, int]."
-                )
-
-            processed_wav_array = self._preprocess_audio(current_wav, source_sr)
-            mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav_array)
 
-
-
-            actual_mel_lengths.append(feature_tensor.shape[0])
 
-        [removed lines not captured in this view]
-        max_t_mel_in_batch = audio_embeds.shape[1]
-
-        attention_mask = torch.zeros(len(audios), max_t_mel_in_batch,
-                                     dtype=torch.bool)  # Device handled by BatchFeature
-        for i, length in enumerate(actual_mel_lengths):
-            attention_mask[i, :length] = True
 
         output_data = {
             "audio_values": audio_embeds,
-            "audio_attention_mask": attention_mask,
         }
-
-
-        output_data["audio_values_sizes"] = torch.stack(sizes_for_embed_length)
-
-        logger.debug(
-            f"Gemma3AudioFeatureExtractor: Output 'audio_values' shape: {output_data['audio_values'].shape}")  # Verify output
 
         return BatchFeature(data=output_data, tensor_type=return_tensors)
 
     def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
-
-            if np.issubdtype(wav.dtype, np.integer):
-                max_val = np.iinfo(wav.dtype).max if wav.size > 0 else 1.0
-                wav = wav.astype(np.float32) / max_val
-            else:
-                wav = wav.astype(np.float32)
-        elif wav.dtype == np.float64:
-            wav = wav.astype(np.float32)
-
         if wav.ndim > 1:
             wav = wav.mean(axis=0)
-
         if source_sr != self.sampling_rate:
-
-
-            up_factor = self.sampling_rate // common_divisor
-            down_factor = source_sr // common_divisor
-            if up_factor != down_factor:
-                wav = scipy.signal.resample_poly(wav, up=up_factor, down=down_factor)
-
-        max_abs_val = np.abs(wav).max()
-        if max_abs_val > 1e-7:
-            wav = wav / max_abs_val
-        return wav
 
     def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
-        [removed lines not captured in this view]
-        if len(wav) >= self.win_length:
-            num_frames = 1 + (len(wav) - self.win_length) // self.hop_length
-        else:
-            num_frames = 0
-
-        if num_frames <= 0:
-            # logger.warning(...)  # logger might not be defined
-            return np.zeros((0, self.n_mels), dtype=np.float32)
-
-        frames_view = np.lib.stride_tricks.as_strided(
             wav,
-            shape=(num_frames, self.win_length),
-            strides=(wav.strides[0] * self.hop_length, wav.strides[0]),
             writeable=False
-        )
-
-        frames_data *= self.window
 
-        spectrum = np.fft.rfft(frames_data, n=self.n_fft)
         power = np.abs(spectrum) ** 2
         mel_spectrogram = np.dot(power, self.mel_filterbank)
-        mel_spectrogram = np.clip(mel_spectrogram, LOG_MEL_CLIP_EPSILON, None)
-
-
-        return log_mel_spectrogram.astype(np.float32)
 
     def _calculate_embed_length(self, frame_count: int) -> int:
         compressed = math.ceil(frame_count / self.compression_rate)
@@ -266,9 +156,8 @@ class Gemma3ImagesKwargs(ImagesKwargs):
 
 
 class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
-    images_kwargs: Optional[Dict[str, Any]] = None
-    audio_kwargs: Optional[Dict[str, Any]] = None
-    text_kwargs: Optional[Dict[str, Any]] = None
     _defaults = {
         "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
         "images_kwargs": {},
@@ -279,23 +168,38 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
 class Gemma3OmniProcessor(ProcessorMixin):
     attributes = ["image_processor", "audio_processor", "tokenizer"]
     valid_kwargs = ["chat_template", "image_seq_length"]
-
-    # --- CRITICAL FIX: Use STRING names for auto-loading by ProcessorMixin ---
     image_processor_class = "AutoImageProcessor"
-    audio_processor_class = "AutoFeatureExtractor"
     tokenizer_class = "AutoTokenizer"
 
     def __init__(
         self,
-        image_processor,
-        audio_processor,
-        tokenizer,
         chat_template=None,
         image_seq_length: int = 256,
-        **kwargs
     ):
-
-
         super().__init__(
             image_processor=image_processor,
            audio_processor=audio_processor,
@@ -304,281 +208,136 @@ class Gemma3OmniProcessor(ProcessorMixin):
             **kwargs
         )
 
-        [removed lines not captured in this view]
-            # User's original warning logic for audio_token_id
-            # if self.audio_token_id != self.expected_audio_token_id:  # Comparing to a fixed ID
-            #     logger.warning(f"Assigned ID {self.audio_token_id} for '{self.audio_token_str_from_user_code}' does not match expected ID {self.expected_audio_token_id}.")
-            if hasattr(self.tokenizer, "unk_token_id") and self.audio_token_id == self.tokenizer.unk_token_id:
-                logger.warning(
-                    f"Audio token '{self.audio_token_str_from_user_code}' not found in tokenizer, maps to UNK. Ensure it's added as a special token.")
-
-            self.full_image_sequence = f"\n\n{self.boi_token}{''.join([self.image_token] * image_seq_length)}{self.eoi_token}\n\n"
-        else:
-            # This case should ideally not happen if from_pretrained works correctly.
-            logger.error(
-                "Gemma3OmniProcessor initialized, but tokenizer is None. Token-dependent attributes will be missing or use placeholders.")
-            self.image_token_id = None
-            self.boi_token = "<UNUSED_BOI>"
-            self.image_token = "<UNUSED_IMG_TOKEN>"
-            self.eoi_token = "<UNUSED_EOI>"
-            self.audio_token_str_from_user_code = "<audio_soft_token>"
-            self.audio_token_id = -1
-            self.full_image_sequence = ""
-
-        # These are parameters for this processor's logic of determining audio token sequence length for prompts
-        # They were fixed values in user's original __init__
-        self.prompt_audio_compression_rate = 8
-        self.prompt_audio_qformer_compression_rate = 1
-        self.prompt_audio_feat_stride = 1
-
-    def _merge_kwargs(self, KwargsClassWithDefaults, tokenizer_init_kwargs, **kwargs_from_call):
-        final_kwargs = {}
-        _defaults = getattr(KwargsClassWithDefaults, "_defaults", {})
-        if not isinstance(_defaults, dict): _defaults = {}
-
-        for modality_key, default_modality_kwargs in _defaults.items():
-            final_kwargs[modality_key] = default_modality_kwargs.copy()
-
-        for modality_key_in_call, modality_kwargs_in_call in kwargs_from_call.items():
-            if modality_key_in_call in final_kwargs:
-                if isinstance(modality_kwargs_in_call, dict):
-                    final_kwargs[modality_key_in_call].update(modality_kwargs_in_call)
-            elif isinstance(modality_kwargs_in_call, dict):
-                final_kwargs[modality_key_in_call] = modality_kwargs_in_call.copy()
-
-        for modality_key in final_kwargs:
-            modality_dict = final_kwargs[modality_key]
-            if isinstance(modality_dict, dict) and self.tokenizer is not None:  # Check tokenizer exists
-                for key_in_mod_dict in list(modality_dict.keys()):
-                    if key_in_mod_dict in tokenizer_init_kwargs:
-                        value = (
-                            getattr(self.tokenizer, key_in_mod_dict)
-                            if hasattr(self.tokenizer, key_in_mod_dict)
-                            else tokenizer_init_kwargs[key_in_mod_dict]
-                        )
-                        modality_dict[key_in_mod_dict] = value
 
-        [removed lines not captured in this view]
 
-
 
-        [removed lines not captured in this view]
-        result = math.ceil(
-        return math.ceil(result / self.prompt_audio_qformer_compression_rate)
 
     def __call__(
         self,
-        [removed lines not captured in this view]
-        **kwargs: Any
     ) -> BatchFeature:
-        if text is None and images is None:
-            raise ValueError("Provide at least one of `text` or `images`.")
 
-        [removed lines not captured in this view]
-            self.tokenizer.init_kwargs if hasattr(self.tokenizer, 'init_kwargs') else {},
             **kwargs
         )
 
-        if final_rt is None:
-            final_rt = merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", TensorType.PYTORCH)
-        else:
-            merged_call_kwargs.get("text_kwargs", {}).pop("return_tensors", None)
-
-        if text is None:
-            num_samples = 0
-            if images is not None:
-                _images_list = images if isinstance(images, list) and (
-                    not images or not isinstance(images[0], (int, float))) else [images]
-                num_samples = len(_images_list)
-            elif audios is not None:
-                _audios_list = audios if isinstance(audios, list) else [audios]
-                num_samples = len(_audios_list)
-            text = [""] * num_samples if num_samples > 0 else [""]
-
         if isinstance(text, str):
             text = [text]
-
-            raise ValueError("Input text must be a string or list of strings")
 
-
-
         if images is not None:
-            if self.image_processor is None:
-                raise ValueError("Images were provided, but `self.image_processor` is not set.")
             batched_images = make_nested_list_of_images(images)
-
-                **merged_call_kwargs.get("images_kwargs", {}))
-            image_features_dict = _img_proc_output.data if isinstance(_img_proc_output,
-                                                                      BatchFeature) else _img_proc_output
-
-            if len(text) == 0 and len(batched_images) > 0:  # If text was initially None and images provided
-                text = [" ".join([self.boi_token] * len(img_batch)) for img_batch in batched_images]
-            elif len(batched_images) != len(text):
-                raise ValueError(f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts")
-
-            num_crops_popped = image_features_dict.pop("num_crops", None)
-            if num_crops_popped is not None:
-                num_crops_all = to_py_obj(num_crops_popped)
-                # ... (user's complex crop and text modification logic - kept as per original) ...
-                # This part needs careful attention to ensure num_crops_all aligns with batched_images
-                # For simplicity, the following is a conceptual placeholder of the user's original intent
-                processed_text_for_images = []
-                current_crop_idx_offset = 0
-                for batch_idx, (prompt, current_imgs_in_batch) in enumerate(zip(text, batched_images)):
-                    crops_for_this_batch_sample = []
-                    if num_crops_all:  # Check if num_crops_all is not empty
-                        for _ in current_imgs_in_batch:
-                            if current_crop_idx_offset < len(num_crops_all):
-                                crops_for_this_batch_sample.append(num_crops_all[current_crop_idx_offset])
-                                current_crop_idx_offset += 1
-                            else:
-                                crops_for_this_batch_sample.append(0)  # Should not happen
-
-                    image_indexes = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
-                    # ... (The rest of user's loop for image token replacement) ...
-                    # This was:
-                    # for num, idx in reversed(list(zip(crops_for_this_batch_sample, image_indexes))):
-                    #     if num > 0 : ...
-                    # text[batch_idx] = prompt
-                    # For minimal change, I'll assume this part is complex and specific.
-                    # A simplified version:
-                    prompt_with_full_seq = prompt.replace(self.boi_token, self.full_image_sequence,
-                                                          len(current_imgs_in_batch) if image_indexes else 0)
-                    processed_text_for_images.append(prompt_with_full_seq)
-                text = processed_text_for_images
-            else:  # if no num_crops, simpler replacement
-                text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
-
-        # --- Audio Processing ---
-        audio_features_dict = {}
-        if audios is not None:
-            if self.audio_processor is None:
-                raise ValueError("Audios were provided, but `self.audio_processor` is not set.")
-
-            audio_call_kwargs = merged_call_kwargs.get("audio_kwargs", {})
-            if sampling_rate is not None:
-                audio_call_kwargs["sampling_rate"] = sampling_rate
-
-            _audio_proc_output = self.audio_processor(audios=audios, return_tensors=None, **audio_call_kwargs)
-            audio_features_dict = _audio_proc_output.data
-            logger.debug(
-                f"Gemma3OmniProcessor: Shape of 'audio_values' from Feature Extractor: {audio_features_dict['audio_values'].shape}")
 
-
-
 
-        if len(
             raise ValueError(
-                f"Inconsistent batch sizes
-        [removed lines not captured in this view]
             logger.debug(
-                f"Sample {i}: Audio tokens
-                f"
-                f"Text snippet='{txt[:100]}...', Input IDs length={len(ids)}"
             )
 
-        [removed lines not captured in this view]
-        if self.audio_token_id != -1:  # Check if audio_token_id is valid
-            mm_token_type_ids_np[padded_input_ids_for_token_type == self.audio_token_id] = 2
-        text_features_dict["token_type_ids"] = mm_token_type_ids_np.tolist()
-
-        # Ensure attention_mask from tokenizer is also included if padding was applied by tokenizer
-        # text_features_dict should already contain 'attention_mask' if padding=True for tokenizer
-
-        final_batch_data = {**text_features_dict}
-        if image_features_dict:
-            final_batch_data.update(image_features_dict)
-        if audio_features_dict:
-            final_batch_data.update(audio_features_dict)
-
-        return BatchFeature(data=final_batch_data, tensor_type=final_rt)
-
-    # Helper for padding list of lists, if tokenizer does not do it with return_tensors=None
-    def _pad_মাদের(self, list_of_lists: List[List[int]], padding_value: int = 0) -> Tuple[np.ndarray, np.ndarray]:
-        if not list_of_lists: return np.array([]), np.array([])
-        max_len = max(len(sublist) for sublist in list_of_lists)
-        padded_array = np.full((len(list_of_lists), max_len), padding_value, dtype=int)
-        attention_mask = np.zeros((len(list_of_lists), max_len), dtype=int)
-        for i, sublist in enumerate(list_of_lists):
-            padded_array[i, :len(sublist)] = sublist
-            attention_mask[i, :len(sublist)] = 1
-        return padded_array, attention_mask
 
     def batch_decode(self, *args, **kwargs):
         return self.tokenizer.batch_decode(*args, **kwargs)
@@ -588,18 +347,7 @@ class Gemma3OmniProcessor(ProcessorMixin):
 
     @property
     def model_input_names(self):
-        [removed lines not captured in this view]
-        image_processor_inputs = []
-        if hasattr(self, 'image_processor') and self.image_processor is not None:
-            image_processor_inputs = self.image_processor.model_input_names
-
-        audio_processor_inputs = []
-        if hasattr(self, 'audio_processor') and self.audio_processor is not None:
-            audio_processor_inputs = getattr(self.audio_processor, "model_input_names",
-                                             ["audio_values", "audio_attention_mask"])
-
-        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs + audio_processor_inputs))
+++ processing_gemma3_omni.py (added / new side)

@@ -1,16 +1,16 @@
 import re
+from typing import List, Optional, Union, Dict, Any
 
 import math
 import numpy as np
 import scipy.signal
 import torch
 from torch.nn.utils.rnn import pad_sequence
+from transformers.audio_utils import AudioInput
 from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
 from transformers.feature_extraction_utils import BatchFeature
 from transformers.image_utils import make_nested_list_of_images
+from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs, Unpack
 from transformers.utils import TensorType, to_py_obj, logging
 
 # Constants

@@ -25,7 +25,6 @@ DEFAULT_FEAT_STRIDE = 4
 IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
 AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
 DEFAULT_MAX_LENGTH = 16384
 
 logger = logging.get_logger(__name__)
 
@@ -33,41 +32,25 @@ logger = logging.get_logger(__name__)
 def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                           fmax: Optional[float] = None) -> np.ndarray:
     """Create Mel filterbank for audio processing."""
+    fmax = fmax or sampling_rate / 2
 
     def hz_to_mel(f: float) -> float:
         return 1127.0 * math.log(1 + f / 700.0)
 
     mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
     freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
     bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
 
     filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
+    for m in range(1, n_mels + 1):
+        left, center, right = bins[m - 1:m + 2]
+        filterbank[m - 1, left:center] = (np.arange(left, center) - left) / (center - left)
+        filterbank[m - 1, center:right] = (right - np.arange(center, right)) / (right - center)
 
     return filterbank
 
 
 class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
     def __init__(
         self,
         compression_rate: int = DEFAULT_COMPRESSION_RATE,
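
For orientation, a minimal sketch of how a filterbank returned by create_mel_filterbank is consumed by _compute_log_mel_spectrogram below: the extractor stores the transpose and right-multiplies each frame's FFT power spectrum by it. The concrete numbers here (16 kHz, n_fft=512, n_mels=80) are illustrative assumptions, not necessarily the file's DEFAULT_* values.

import numpy as np

fb = create_mel_filterbank(sampling_rate=16000, n_fft=512, n_mels=80)  # (n_mels, n_fft // 2 + 1) -> (80, 257)
frame = np.hamming(400) * np.random.randn(400)                         # one windowed frame of samples
power = np.abs(np.fft.rfft(frame, n=512)) ** 2                         # (257,)
mel_energies = power @ fb.T                                            # (80,)
log_mel = np.log(np.clip(mel_energies, 1.0, None))                     # same clip-then-log as the extractor
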
@@ -75,182 +58,89 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
         feat_stride: int = DEFAULT_FEAT_STRIDE,
         sampling_rate: int = DEFAULT_SAMPLING_RATE,
         n_fft: int = DEFAULT_N_FFT,
+        win_length: int = DEFAULT_WIN_LENGTH,
+        hop_length: int = DEFAULT_HOP_LENGTH,
         n_mels: int = DEFAULT_N_MELS,
         **kwargs
     ):
         kwargs.pop("feature_size", None)
         kwargs.pop("sampling_rate", None)
         kwargs.pop("padding_value", None)
 
         super().__init__(
+            feature_size=n_mels,
+            sampling_rate=sampling_rate,
+            padding_value=0.0,
             **kwargs
         )
 
         self.compression_rate = compression_rate
         self.qformer_rate = qformer_rate
         self.feat_stride = feat_stride
+        self.sampling_rate = sampling_rate
 
+        self.window = np.hamming(win_length).astype(np.float32)
+        self.mel_filterbank = create_mel_filterbank(sampling_rate, n_fft, n_mels).T
         self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
 
     def __call__(
         self,
+        audios: List[AudioInput],
        return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
     ) -> BatchFeature:
+        features, sizes, frames = [], [], []
 
+        for wav in audios:
+            processed_wav = self._preprocess_audio(wav, 22500)
+            mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
+            feature_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32)
+            features.append(feature_tensor)
+            sizes.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
+            frames.append(feature_tensor.shape[0] * self.feat_stride)
 
+        audio_embeds = pad_sequence(features, batch_first=True)
+        size_tensor = torch.stack(sizes)
 
+        attention_mask = None
+        if len(audios) > 1:
+            frame_lengths = torch.tensor(frames)
+            attention_mask = torch.arange(frame_lengths.max()).unsqueeze(0) < frame_lengths.unsqueeze(1)
 
         output_data = {
             "audio_values": audio_embeds,
+            "audio_values_sizes": size_tensor
         }
+        if attention_mask is not None:
+            output_data["audio_attention_mask"] = attention_mask
 
         return BatchFeature(data=output_data, tensor_type=return_tensors)
 
     def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
+        wav = torch.as_tensor(wav).float().numpy()
         if wav.ndim > 1:
             wav = wav.mean(axis=0)
         if source_sr != self.sampling_rate:
+            wav = scipy.signal.resample_poly(wav, self.sampling_rate, source_sr)
+        return wav / max(np.abs(wav).max(), 1e-6)
 
     def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
+        frame_count = 1 + (len(wav) - self.win_length) // self.hop_length
+        strides = wav.strides[0]
+        frames = np.lib.stride_tricks.as_strided(
             wav,
+            shape=(frame_count, self.win_length),
+            strides=(strides * self.hop_length, strides),
             writeable=False
+        ).copy()
+        frames *= self.window
 
+        spectrum = np.fft.rfft(frames, n=self.n_fft).astype(np.complex64)
         power = np.abs(spectrum) ** 2
         mel_spectrogram = np.dot(power, self.mel_filterbank)
+        mel_spectrogram = np.clip(mel_spectrogram, 1.0, None)
+        return np.log(mel_spectrogram, dtype=np.float32)
 
     def _calculate_embed_length(self, frame_count: int) -> int:
         compressed = math.ceil(frame_count / self.compression_rate)
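
A rough usage sketch for the rewritten feature extractor, assuming the DEFAULT_* constants give a working configuration; note that __call__ expects a list of waveforms and that _preprocess_audio currently treats every input as 22 500 Hz audio (the hard-coded source rate above).

import numpy as np

extractor = Gemma3AudioFeatureExtractor()
clips = [np.random.randn(22500), np.random.randn(45000)]    # two dummy clips, ~1 s and ~2 s at 22.5 kHz
batch = extractor(clips, return_tensors="pt")

print(batch["audio_values"].shape)          # (2, max_mel_frames, n_mels), zero-padded per clip
print(batch["audio_values_sizes"])          # per-clip embed lengths from _calculate_embed_length
print(batch["audio_attention_mask"].shape)  # only present because len(clips) > 1
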
@@ -266,9 +156,8 @@ class Gemma3ImagesKwargs(ImagesKwargs):
 
 
 class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Dict[str, Any]
+    audio_kwargs: Dict[str, Any]
     _defaults = {
         "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
         "images_kwargs": {},
@@ -279,23 +168,38 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
 class Gemma3OmniProcessor(ProcessorMixin):
     attributes = ["image_processor", "audio_processor", "tokenizer"]
     valid_kwargs = ["chat_template", "image_seq_length"]
     image_processor_class = "AutoImageProcessor"
+    audio_processor_class = "AutoFeatureExtractor"
     tokenizer_class = "AutoTokenizer"
 
     def __init__(
         self,
+        image_processor,
+        audio_processor,
+        tokenizer,
         chat_template=None,
         image_seq_length: int = 256,
+        **kwargs
     ):
+        self.image_seq_length = image_seq_length
+        self.image_token_id = tokenizer.image_token_id
+        self.boi_token = tokenizer.boi_token
+        self.image_token = tokenizer.image_token
+        self.audio_token = "<audio_soft_token>"
+        self.expected_audio_token_id = 262143
+        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{''.join([tokenizer.image_token] * image_seq_length)}{tokenizer.eoi_token}\n\n"
+
+        self.compression_rate = 8
+        self.qformer_compression_rate = 1
+        self.feat_stride = 1
+
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
+        if self.audio_token_id != self.expected_audio_token_id:
+            logger.warning(
+                f"Assigned ID {self.audio_token_id} for '{self.audio_token}' does not match expected ID {self.expected_audio_token_id}. "
+                "Using assigned ID. Model embedding layer may need resizing."
+            )
+
         super().__init__(
             image_processor=image_processor,
             audio_processor=audio_processor,
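
A hedged construction sketch. The checkpoint path is a placeholder, and the tokenizer must be a Gemma 3 tokenizer that exposes boi_token / image_token / eoi_token / image_token_id and knows the <audio_soft_token> special token, otherwise the attribute lookups in __init__ above will fail.

from transformers import AutoImageProcessor, AutoTokenizer

ckpt = "path/to/gemma3-omni-checkpoint"  # placeholder, not defined by this commit
tokenizer = AutoTokenizer.from_pretrained(ckpt)
image_processor = AutoImageProcessor.from_pretrained(ckpt)
audio_processor = Gemma3AudioFeatureExtractor()

processor = Gemma3OmniProcessor(
    image_processor=image_processor,
    audio_processor=audio_processor,
    tokenizer=tokenizer,
)
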
@@ -304,281 +208,136 @@ class Gemma3OmniProcessor(ProcessorMixin):
             **kwargs
         )
 
+    def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs):
+        default_kwargs = {}
+        for modality in ModelProcessorKwargs._defaults:
+            default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
+
+        # Update defaults with tokenizer init kwargs
+        for modality in default_kwargs:
+            modality_kwargs = default_kwargs[modality]
+            for key in modality_kwargs:
+                if key in tokenizer_init_kwargs:
+                    value = (
+                        getattr(self.tokenizer, key)
+                        if hasattr(self.tokenizer, key)
+                        else tokenizer_init_kwargs[key]
+                    )
+                    modality_kwargs[key] = value
 
+        # Update with user-provided kwargs
+        for modality in default_kwargs:
+            if modality in kwargs:
+                default_kwargs[modality].update(kwargs[modality])
 
+        # Ensure text_kwargs has truncation=False and large max_length
+        default_kwargs["text_kwargs"]["truncation"] = False
+        default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get("max_length",
+                                                                                        DEFAULT_MAX_LENGTH)
 
+        return default_kwargs
+
+    def _compute_audio_embed_size(self, audio_frames: int) -> int:
+        result = math.ceil(audio_frames / self.compression_rate)
+        return math.ceil(result / self.qformer_compression_rate)
 
     def __call__(
         self,
+        images=None,
+        text=None,
+        videos=None,
+        audio=None,
+        **kwargs: Unpack[Gemma3ProcessorKwargs]
     ) -> BatchFeature:
+        if text is None and images is None:
+            raise ValueError("Provide at least one of `text` or `images`.")
 
+        output_kwargs = self._merge_kwargs(
+            Gemma3ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
             **kwargs
         )
 
         if isinstance(text, str):
             text = [text]
+        elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
+            raise ValueError("Input text must be a string or list of strings")
 
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
+        image_inputs = {}
         if images is not None:
             batched_images = make_nested_list_of_images(images)
+            image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
 
+            if not text:
+                text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
 
+            if len(batched_images) != len(text):
                 raise ValueError(
+                    f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts"
+                )
+
+            num_crops = to_py_obj(image_inputs.pop("num_crops"))
+            batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
+
+            for batch_idx, (prompt, images, crops) in enumerate(zip(text, batched_images, batch_num_crops)):
+                image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
+                if len(images) != len(image_indexes):
+                    raise ValueError(
+                        f"Prompt has {len(image_indexes)} image tokens but received {len(images)} images"
+                    )
+
+                for num, idx in reversed(list(zip(crops, image_indexes))):
+                    if num:
+                        formatted_image_text = (
+                            f"Here is the original image {self.boi_token} and here are some crops to help you see better "
+                            + " ".join([self.boi_token] * num)
+                        )
+                        prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token):]
+                        text[batch_idx] = prompt
+
+            text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
+
+        audio_inputs = {}
+        if audio is not None:
+            audio_inputs = self.audio_processor(audio, return_tensors)
+            audio_embeds = audio_inputs['audio_values']
+            audio_frames = audio_embeds.shape[1] * self.feat_stride
+            audio_seq_length = self._compute_audio_embed_size(audio_frames)
+
+            audio_tokens = {
+                "boa_token": "<start_of_audio>",
+                "eoa_token": "<end_of_audio>",
+                "audio_token": "<audio_soft_token>",
+                "boa_token_id": 256001,
+                "eoa_token_id": 256002,
+                "audio_token_id": self.audio_token_id  # Use dynamic ID
+            }
+
+            audio_sequence = f"\n\n{audio_tokens['boa_token']}{''.join([audio_tokens['audio_token']] * audio_seq_length)}{audio_tokens['eoa_token']}\n\n"
+            text = [prompt.replace(audio_tokens['boa_token'], audio_sequence) for prompt in text]
+
+        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors=return_tensors)
+
+        # Debug: Log text and token counts before validation
+        for i, (txt, ids) in enumerate(zip(text, text_inputs["input_ids"])):
+            audio_text_count = txt.count(self.audio_token)
+            audio_ids_count = list(ids).count(self.audio_token_id)
             logger.debug(
+                f"Sample {i}: Audio tokens in text={audio_text_count}, in input_ids={audio_ids_count}, "
+                f"Text length={len(txt)}, Input IDs length={len(ids)}"
             )
 
+        array_ids = text_inputs["input_ids"]
+        if return_tensors == "pt":
+            mm_token_type_ids = torch.zeros_like(array_ids)
+        else:
+            mm_token_type_ids = np.zeros_like(array_ids)
+        mm_token_type_ids[array_ids == self.image_token_id] = 1  # Image token type
+        mm_token_type_ids[array_ids == self.audio_token_id] = 2  # Audio token type
+        text_inputs["token_type_ids"] = mm_token_type_ids
+
+        return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
 
     def batch_decode(self, *args, **kwargs):
         return self.tokenizer.batch_decode(*args, **kwargs)
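
A sketch of a combined text-plus-audio call, continuing from the construction snippet above and assuming the tokenizer knows the audio special tokens; the prompt must contain the <start_of_audio> marker that __call__ expands. With the values set in __init__ (compression_rate=8, qformer_compression_rate=1, feat_stride=1), a clip whose audio_values time dimension is, say, 730 frames is expanded to ceil(730 / 8) = 92 copies of <audio_soft_token> between <start_of_audio> and <end_of_audio>.

import numpy as np

wav = np.random.randn(22500 * 5)  # ~5 s of dummy audio at the extractor's hard-coded 22.5 kHz source rate
out = processor(
    text=["Transcribe this clip: <start_of_audio>"],
    audio=[wav],
)
print(out["input_ids"].shape)       # prompt tokens including the inserted audio soft tokens
print(out["audio_values"].shape)    # (1, mel_frames, n_mels)
print(out["token_type_ids"].max())  # 2 where audio soft tokens sit, 1 for image tokens, else 0
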
@@ -588,18 +347,7 @@ class Gemma3OmniProcessor(ProcessorMixin):
 
     @property
     def model_input_names(self):
+        tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
+        image_processor_inputs = self.image_processor.model_input_names
+        audio_processor_inputs = self.audio_processor.model_input_names
+        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs + audio_processor_inputs))