BAAI /
3v324v23 committed on
Commit e3bf9ba · 1 Parent(s): b40f27d
Files changed (28)
  1. README.md +50 -8
  2. llava_arch.py +43 -34
  3. llava_qwen.py +6 -7
  4. multimodal_encoder/.ipynb_checkpoints/base_encoder-checkpoint.py +68 -0
  5. multimodal_encoder/.ipynb_checkpoints/builder-checkpoint.py +29 -0
  6. multimodal_encoder/.ipynb_checkpoints/clip_encoder-checkpoint.py +179 -0
  7. multimodal_encoder/.ipynb_checkpoints/siglip_encoder-checkpoint.py +151 -0
  8. multimodal_encoder/__pycache__/base_encoder.cpython-310.pyc +0 -0
  9. multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  10. multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
  11. multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
  12. multimodal_encoder/base_encoder.py +68 -0
  13. multimodal_encoder/builder.py +20 -0
  14. multimodal_encoder/siglip_encoder.py +154 -0
  15. multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
  16. multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc +0 -0
  17. multimodal_projector/builder.py +65 -0
  18. multimodal_projector/pooler_projector.py +33 -0
  19. multimodal_resampler/__pycache__/builder.cpython-310.pyc +0 -0
  20. multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc +0 -0
  21. multimodal_resampler/__pycache__/perceiver.cpython-310.pyc +0 -0
  22. multimodal_resampler/__pycache__/qformer.cpython-310.pyc +0 -0
  23. multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc +0 -0
  24. multimodal_resampler/builder.py +34 -0
  25. multimodal_resampler/masked_drop.py +80 -0
  26. multimodal_resampler/perceiver.py +155 -0
  27. multimodal_resampler/qformer.py +1160 -0
  28. multimodal_resampler/spatial_pool.py +45 -0
README.md CHANGED
@@ -21,14 +21,13 @@ Video-XL-2 supplies two efficiency optimization strategies: chunk-based prefill and
21
 
22
  TODO
23
  - [X] Release model weights.
24
- - [ ] Release the inference code w/o. efficiency optimization.
25
- - [ ] Release the inference code w. chunk-based prefill.
26
  - [ ] Release the inference code w. chunk-based prefill & bi-level kvs decoding.
27
 
28
  **Tips: Our inference code is still being updated. You can pass "--include '\*.py'" to huggingface-cli to update only the inference code instead of re-downloading the whole model.*
29
 
30
  ---
31
-
32
  ### w/o. efficiency optimization
33
  ```python
34
  from transformers import AutoTokenizer, AutoModel, AutoConfig, BitsAndBytesConfig
@@ -38,7 +37,7 @@ import torch
38
  model_path = '/root/Models/Video-XL-2'
39
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
40
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
41
- model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map=device,quantization_config=None,attn_implementation="sdpa").to(torch.bfloat16)
42
 
43
  gen_kwargs = {
44
  "do_sample": True,
@@ -74,17 +73,60 @@ To enable this mode, you need to set `enable_chunk_prefill` to `True` and config
74
  * **`chunk_prefill_mode`**: This defines the mode of chunk-based prefill. We currently support two modes:
75
  * **`streaming`**: This mode encodes video chunks in a streaming fashion.
76
  * **`mask`**: This mode achieves an equivalent effect using an attention mask. However, due to a lack of underlying optimized operators, the `mask` mode doesn't offer any efficiency improvements at this time. We recommend using the `streaming` mode.
77
- * **`chunk_size`**: This parameter specifies the size of each chunk processed in a single forward pass. The unit for `chunk_size` is 4 frames (e.g., `chunk_size = 4` means processing visual tokens from $4 \times 4 = 16$ frames at once). A larger `chunk_size` will gradually approach full attention, resulting in a higher peak memory usage.
78
  * **`step_size`**: This controls the step size between chunks. A smaller `step_size` leads to more continuous information transfer between chunks but may slightly decrease inference speed.
79
  * **`offload`**: This boolean parameter determines whether to offload the key-value states (KVs) of each chunk to the CPU during forwarding. While this can reduce memory usage, it will also lower the inference speed.
80
- * **`chunk_size_for_vision_tower`**: For longer video inputs, the vision tower can become a memory bottleneck during forwarding. To mitigate this, we also support a streaming mode for the vision tower, which is controlled by this parameter.
81
 
82
  **Tip: Currently, chunk-based prefill only supports the 'sdpa' attention implementation.*
83
 
84
- ---
85
-
86
  ```python
87
 
88
  ```
89
 
90
  ---
 
21
 
22
  TODO
23
  - [X] Release model weights.
24
+ - [X] Release the inference code w/o. efficiency optimization.
25
+ - [X] Release the inference code w. chunk-based prefill.
26
  - [ ] Release the inference code w. chunk-based prefill & bi-level kvs decoding.
27
 
28
  **Tips: Our inference code is still being updated. You can pass "--include '\*.py'" to huggingface-cli to update only the inference code instead of re-downloading the whole model.*
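  For reference, here is a minimal Python sketch of the same selective update using `huggingface_hub` (the repo id `BAAI/Video-XL-2` and the local directory below are assumptions for illustration, not taken from this README):

  ```python
  # Sketch: refresh only the *.py inference files of an existing local copy,
  # assuming the repo id "BAAI/Video-XL-2" and a local checkout at ./Video-XL-2.
  from huggingface_hub import snapshot_download

  snapshot_download(
      repo_id="BAAI/Video-XL-2",   # assumed repo id
      local_dir="./Video-XL-2",    # existing local model directory
      allow_patterns=["*.py"],     # fetch only the inference code
  )
  ```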
29
 
30
  ---
 
31
  ### w/o. efficiency optimization
32
  ```python
33
  from transformers import AutoTokenizer, AutoModel, AutoConfig, BitsAndBytesConfig
 
37
  model_path = '/root/Models/Video-XL-2'
38
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
39
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
40
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map=device,quantization_config=None,attn_implementation="sdpa",torch_dtype=torch.bfloat16)
41
 
42
  gen_kwargs = {
43
  "do_sample": True,
 
73
  * **`chunk_prefill_mode`**: This defines the mode of chunk-based prefill. We currently support two modes:
74
  * **`streaming`**: This mode encodes video chunks in a streaming fashion.
75
  * **`mask`**: This mode achieves an equivalent effect using an attention mask. However, due to a lack of underlying optimized operators, the `mask` mode doesn't offer any efficiency improvements at this time. We recommend using the `streaming` mode.
76
+ * **`chunk_size`**: This parameter specifies the size of each chunk processed in a single forward pass. The unit for `chunk_size` is **4 frames** (e.g., `chunk_size = 4` means processing visual tokens from **4×4 = 16 frames** at once). A larger `chunk_size` gradually approaches full attention, resulting in higher peak memory usage.
77
  * **`step_size`**: This controls the step size between chunks. A smaller `step_size` leads to more continuous information transfer between chunks but may slightly decrease inference speed.
78
  * **`offload`**: This boolean parameter determines whether to offload the key-value states (KVs) of each chunk to the CPU during forwarding. While this can reduce memory usage, it will also lower the inference speed.
79
+ * **`chunk_size_for_vision_tower`**: For longer video inputs, the vision tower can become a memory bottleneck during forwarding. To mitigate this, we also support a streaming mode for the vision tower, controlled by this parameter. The unit for `chunk_size_for_vision_tower` is **1 frame**, and its value must be **a multiple of 4**.
80
 
81
  **Tip: Currently, chunk-based prefill only supports the 'sdpa' attention implementation.*
82
 
 
 
83
  ```python
84
+ from transformers import AutoTokenizer, AutoModel, AutoConfig, BitsAndBytesConfig
85
+ import torch
86
+ import pdb
87
+ import argparse
88
 
89
+ torch.cuda.reset_peak_memory_stats()
90
+ # load model
91
+ model_path = '/share/minghao/Models/Video-XL-2'
92
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
93
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
94
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map=device,quantization_config=None,attn_implementation="sdpa",torch_dtype=torch.bfloat16)
95
+
96
+ gen_kwargs = {"do_sample": False, "temperature": 0.01, "top_p": 0.001, "num_beams": 1, "use_cache": True, "max_new_tokens": 128}
97
+
98
+
99
+ """
100
+ Set params
101
+ With Chunk-based Prefill enabled, Video-XL-2 can process 1,300 frames on a 24GB GPU (using approximately 23.72GB). When combined with bi-level KVS decoding, this capacity increases to 1,800 frames.
102
+ If you have ample resources, you can disable offload and increase chunk_size_for_vision_tower and chunk_size to achieve faster processing.
103
+ """
104
+ model.config.enable_chunk_prefill = True
105
+ prefill_config = {
106
+ 'chunk_prefill_mode': 'streaming',
107
+ 'chunk_size': 4,
108
+ 'step_size': 1,
109
+ 'offload': True,
110
+ 'chunk_size_for_vision_tower': 24,
111
+ }
112
+ model.config.prefill_config = prefill_config
113
+
114
+ # input data
115
+ video_path = "/share/LXRlxr0_0/code/videoxl2/lmm-eval/~/.cache/huggingface/videomme/ZBKUqc_ICpg.mp4"
116
+ question1 = "How many people are in the video? (A) 3 people (B) 6 people. Please respond with only the letter."
117
+
118
+ # params
119
+ max_num_frames = 1300
120
+ sample_fps = None # extract frames at 1 fps
121
+ max_sample_fps = None
122
+
123
+ with torch.inference_mode():
124
+ response = model.chat(video_path, tokenizer, question1, chat_history=None, return_history=False,max_num_frames=max_num_frames, sample_fps=sample_fps, max_sample_fps=max_sample_fps, generation_config=gen_kwargs)
125
+
126
+
127
+ peak_memory_allocated = torch.cuda.max_memory_allocated()
128
+ print(f"Memory Peak: {peak_memory_allocated / (1024**3):.2f} GB") # 23.72GB
129
+ print(response)
130
  ```
131
 
132
  ---
llava_arch.py CHANGED
@@ -12,7 +12,6 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
-
16
  from abc import ABC, abstractmethod
17
  import importlib.util
18
  import os.path as osp
@@ -23,30 +22,29 @@ import torch
23
  import torch.nn as nn
24
  import torch.nn.functional as F
25
 
26
- try:
27
- from .builder import build_vision_tower
28
- from .builder import build_vision_resampler
29
- from .builder import build_vision_projector
30
- except ModuleNotFoundError:
31
- spec = importlib.util.spec_from_file_location(
32
- "builder",
33
- osp.join(osp.dirname(__file__), "builder.py"),
34
- )
35
- builder = importlib.util.module_from_spec(spec)
36
- spec.loader.exec_module(builder)
37
- build_vision_tower = getattr(
38
- builder,
39
- "build_vision_tower",
40
- )
41
- build_vision_resampler = getattr(
42
- builder,
43
- "build_vision_resampler",
44
- )
45
- build_vision_projector = getattr(
46
- builder,
47
- "build_vision_projector",
48
- )
49
-
50
 
51
  from transformers import AutoTokenizer
52
 
@@ -59,6 +57,7 @@ from .sae import SiglipAE
59
  import numpy as np
60
  import torch.nn.functional as F
61
  import pdb
 
62
  class LlavaMetaModel:
63
 
64
  def __init__(self, config):
@@ -304,18 +303,22 @@ class LlavaMetaForCausalLM(ABC):
304
  return expanded_x
305
 
306
  def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
307
- pdb.set_trace()
308
  if self.config.enable_chunk_prefill:
309
  chunk_size_for_vision_tower = self.config.prefill_config['chunk_size_for_vision_tower']
310
  else:
311
  chunk_size_for_vision_tower = 100000
 
312
  # Define the maximum batch size (1024 frames)
313
  max_batch_size = chunk_size_for_vision_tower
 
314
  num_frames = videos_or_images.shape[0]
315
  # Initialize a list to store the features from each batch
316
  videos_or_images_features = []
317
 
 
 
318
  # Split videos_or_images into smaller batches if num_frames > max_batch_size
 
319
  if num_frames > max_batch_size:
320
  # Calculate the number of batches needed
321
  num_batches = (num_frames + max_batch_size - 1) // max_batch_size
@@ -326,23 +329,29 @@ class LlavaMetaForCausalLM(ABC):
326
  # Process each batch separately
327
  batch_videos_or_images = videos_or_images[start_idx:end_idx]
328
  batch_features = self.get_model().get_vision_tower()(batch_videos_or_images)
329
- videos_or_images_features.append(batch_features)
330
 
 
 
 
 
 
 
331
  # Concatenate the features of all batches
332
- videos_or_images_features = torch.cat(videos_or_images_features, dim=0)
333
  else:
334
  videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
335
 
336
  per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0)
337
  all_videos_or_images_features = []
338
 
339
- peak_memory_allocated = torch.cuda.max_memory_allocated()
340
- print(f"vision encoder 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
341
-
342
  del videos_or_images_features
343
  torch.cuda.empty_cache()
344
 
345
  chunk_size = chunk_size_for_vision_tower
 
346
  all_feat_list = []
347
  for idx, feat in enumerate(per_videos_or_images_features):
348
  for i in range(0, feat.shape[0], chunk_size):
@@ -365,8 +374,8 @@ class LlavaMetaForCausalLM(ABC):
365
  all_feat_list.append(batched_feat)
366
 
367
  feat = torch.cat(all_feat_list, dim=0)
368
- peak_memory_allocated = torch.cuda.max_memory_allocated()
369
- print(f"sae 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
370
 
371
  del per_videos_or_images_features
372
  del all_feat_list
@@ -406,7 +415,7 @@ class LlavaMetaForCausalLM(ABC):
406
  return image_features
407
 
408
  def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None,time_embedding=None):
409
- pdb.set_trace()
410
  vision_tower = self.get_vision_tower()
411
  if vision_tower is None or images is None or input_ids.shape[1] == 1:
412
  return input_ids, position_ids, attention_mask, past_key_values, None, labels
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  from abc import ABC, abstractmethod
16
  import importlib.util
17
  import os.path as osp
 
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
 
25
+ # try:
26
+ from .multimodal_encoder.builder import build_vision_tower
27
+ from .multimodal_resampler.builder import build_vision_resampler
28
+ from .multimodal_projector.builder import build_vision_projector
29
+ # except ModuleNotFoundError:
30
+ # spec = importlib.util.spec_from_file_location(
31
+ # "builder",
32
+ # osp.join(osp.dirname(__file__), "builder.py"),
33
+ # )
34
+ # builder = importlib.util.module_from_spec(spec)
35
+ # spec.loader.exec_module(builder)
36
+ # build_vision_tower = getattr(
37
+ # builder,
38
+ # "build_vision_tower",
39
+ # )
40
+ # build_vision_resampler = getattr(
41
+ # builder,
42
+ # "build_vision_resampler",
43
+ # )
44
+ # build_vision_projector = getattr(
45
+ # builder,
46
+ # "build_vision_projector",
47
+ # )
 
48
 
49
  from transformers import AutoTokenizer
50
 
 
57
  import numpy as np
58
  import torch.nn.functional as F
59
  import pdb
60
+
61
  class LlavaMetaModel:
62
 
63
  def __init__(self, config):
 
303
  return expanded_x
304
 
305
  def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
 
306
  if self.config.enable_chunk_prefill:
307
  chunk_size_for_vision_tower = self.config.prefill_config['chunk_size_for_vision_tower']
308
  else:
309
  chunk_size_for_vision_tower = 100000
310
+ # pdb.set_trace()
311
  # Define the maximum batch size (1024 frames)
312
  max_batch_size = chunk_size_for_vision_tower
313
+ # print(f'max_batch_size: {max_batch_size}')
314
  num_frames = videos_or_images.shape[0]
315
  # Initialize a list to store the features from each batch
316
  videos_or_images_features = []
317
 
318
+ videos_or_images_features = torch.empty((num_frames, 729, 1152), device=self.get_model().device, dtype=self.get_model().dtype)
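+ # Feature buffer shape note: 729 patch tokens per frame and hidden size 1152 are assumed here to come from the SigLIP-so400m vision tower used by this model.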
319
+
320
  # Split videos_or_images into smaller batches if num_frames > max_batch_size
321
+ current_idx = 0
322
  if num_frames > max_batch_size:
323
  # Calculate the number of batches needed
324
  num_batches = (num_frames + max_batch_size - 1) // max_batch_size
 
329
  # Process each batch separately
330
  batch_videos_or_images = videos_or_images[start_idx:end_idx]
331
  batch_features = self.get_model().get_vision_tower()(batch_videos_or_images)
332
+ # videos_or_images_features.append(batch_features)
333
 
334
+ videos_or_images_features[current_idx:current_idx + batch_features.shape[0]] = batch_features
335
+ # Update the current index for the next batch
336
+ current_idx += batch_features.shape[0]
337
+ # peak_memory_allocated = torch.cuda.max_memory_allocated()
338
+ # print(f"vision encoder 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
339
+
340
  # Concatenate the features of all batches
341
+ # videos_or_images_features = torch.cat(videos_or_images_features, dim=0)
342
  else:
343
  videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
344
 
345
  per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0)
346
  all_videos_or_images_features = []
347
 
348
+ # peak_memory_allocated = torch.cuda.max_memory_allocated()
349
+ # print(f"vision encoder 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
 
350
  del videos_or_images_features
351
  torch.cuda.empty_cache()
352
 
353
  chunk_size = chunk_size_for_vision_tower
354
+ print(f'chunk_size: {chunk_size}')
355
  all_feat_list = []
356
  for idx, feat in enumerate(per_videos_or_images_features):
357
  for i in range(0, feat.shape[0], chunk_size):
 
374
  all_feat_list.append(batched_feat)
375
 
376
  feat = torch.cat(all_feat_list, dim=0)
377
+ # peak_memory_allocated = torch.cuda.max_memory_allocated()
378
+ # print(f"sae 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
379
 
380
  del per_videos_or_images_features
381
  del all_feat_list
 
415
  return image_features
416
 
417
  def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None,time_embedding=None):
418
+
419
  vision_tower = self.get_vision_tower()
420
  if vision_tower is None or images is None or input_ids.shape[1] == 1:
421
  return input_ids, position_ids, attention_mask, past_key_values, None, labels
llava_qwen.py CHANGED
@@ -22,6 +22,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaMod
22
  from transformers.modeling_outputs import CausalLMOutputWithPast
23
  from transformers.generation.utils import GenerateOutput
24
  from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
 
25
  from .modeling_qwen2 import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
26
  import pdb
27
  import time
@@ -375,9 +376,8 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
375
  values = torch.cat([pkv[1].to(device=device) for pkv in layer_pkvs], dim=2)
376
  merged_pkv.append((keys, values))
377
 
378
- peak_memory_allocated = torch.cuda.max_memory_allocated()
379
- print(f"prefill 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
380
-
381
 
382
  pkv = merged_pkv
383
  del block_streaming_past_key_values
@@ -392,7 +392,6 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
392
  prefill_len = visual_token_end_pos
393
 
394
  # torch.cuda.reset_peak_memory_stats()
395
-
396
  # Process suffix
397
  if suffix_embeds.size(1) > 0:
398
  seq_len = suffix_embeds.size(1)
@@ -413,8 +412,8 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
413
  return_dict=return_dict,
414
  # blocks_positions=None,
415
  )
416
- peak_memory_allocated = torch.cuda.max_memory_allocated()
417
- print(f"decoding 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
418
  del mixed_prefill_past_key_values
419
  torch.cuda.empty_cache()
420
 
@@ -650,7 +649,6 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
650
  sample_fps=1,
651
  max_sample_fps=4,
652
  generation_config={}):
653
- pdb.set_trace()
654
 
655
  # prepare text input
656
  conv = conv_templates["qwen_1_5"].copy()
@@ -668,6 +666,7 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
668
 
669
  # prepare video input
670
  frames, timestamps = load_video(video_path, max_num_frames, fps=sample_fps, max_fps=max_sample_fps)
 
671
 
672
  time_stamps=[]
673
  token_frames_sum=(len(timestamps)+3)//4
 
22
  from transformers.modeling_outputs import CausalLMOutputWithPast
23
  from transformers.generation.utils import GenerateOutput
24
  from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
25
+ # from longva.longva.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
26
  from .modeling_qwen2 import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
27
  import pdb
28
  import time
 
376
  values = torch.cat([pkv[1].to(device=device) for pkv in layer_pkvs], dim=2)
377
  merged_pkv.append((keys, values))
378
 
379
+ # peak_memory_allocated = torch.cuda.max_memory_allocated()
380
+ # print(f"prefill 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
 
381
 
382
  pkv = merged_pkv
383
  del block_streaming_past_key_values
 
392
  prefill_len = visual_token_end_pos
393
 
394
  # torch.cuda.reset_peak_memory_stats()
 
395
  # Process suffix
396
  if suffix_embeds.size(1) > 0:
397
  seq_len = suffix_embeds.size(1)
 
412
  return_dict=return_dict,
413
  # blocks_positions=None,
414
  )
415
+ # peak_memory_allocated = torch.cuda.max_memory_allocated()
416
+ # print(f"decoding 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
417
  del mixed_prefill_past_key_values
418
  torch.cuda.empty_cache()
419
 
 
649
  sample_fps=1,
650
  max_sample_fps=4,
651
  generation_config={}):
 
652
 
653
  # prepare text input
654
  conv = conv_templates["qwen_1_5"].copy()
 
666
 
667
  # prepare video input
668
  frames, timestamps = load_video(video_path, max_num_frames, fps=sample_fps, max_fps=max_sample_fps)
669
+ print(f'video loaded, extracted {len(frames)} frames.')
670
 
671
  time_stamps=[]
672
  token_frames_sum=(len(timestamps)+3)//4
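  # Note: (len(timestamps)+3)//4 is a ceiling division by 4; this is assumed to count 4-frame token groups, matching the 4-frame unit described in the README.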
multimodal_encoder/.ipynb_checkpoints/base_encoder-checkpoint.py ADDED
@@ -0,0 +1,68 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class BaseVisionTower(nn.Module):
8
+ def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower_name
14
+ self.delay_load = delay_load
15
+
16
+ @abstractmethod
17
+ def load_model(self, device_map=None):
18
+ raise NotImplementedError("Subclasses must implement load_model")
19
+
20
+ @abstractmethod
21
+ def _forward(self, images):
22
+ raise NotImplementedError("Subclasses must implement forward")
23
+
24
+ def forward(self, images):
25
+ if type(images) is list:
26
+ image_features = [self._forward(image.unsqueeze(0)) for image in images]
27
+ else:
28
+ image_features = self._forward(images)
29
+
30
+ return image_features
31
+
32
+ @property
33
+ def dummy_feature(self):
34
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
35
+
36
+ @property
37
+ def dtype(self):
38
+ # Dynamically infer the dtype from the first parameter, if not explicitly specified
39
+ if hasattr(self.vision_tower, "dtype"):
40
+ return self.vision_tower.dtype
41
+ else:
42
+ params = list(self.vision_tower.parameters())
43
+ return (
44
+ params[0].dtype if len(params) > 0 else torch.float32
45
+ ) # Default to torch.float32 if no parameters
46
+
47
+ @property
48
+ def device(self):
49
+ # Dynamically infer the device from the first parameter, if not explicitly specified
50
+ if hasattr(self.vision_tower, "device"):
51
+ return self.vision_tower.device
52
+ else:
53
+ params = list(self.vision_tower.parameters())
54
+ return (
55
+ params[0].device if len(params) > 0 else torch.device("cpu")
56
+ ) # Default to CPU if no parameters
57
+ @property
58
+ def config(self):
59
+ if self.is_loaded:
60
+ return self.vision_tower.config
61
+ else:
62
+ return self.cfg_only
63
+ @property
64
+ def hidden_size(self):
65
+ try:
66
+ return self.config.hidden_size
67
+ except:
68
+ return self._hidden_size
multimodal_encoder/.ipynb_checkpoints/builder-checkpoint.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
3
+ from .siglip_encoder import SigLipVisionTower
4
+ # from .eva_clip.eva_clip_encoder import EvaClipVisionTower
5
+ # from .dev_eva_clip.eva_vit import EvaViTWrapper
6
+
7
+
8
+ def build_vision_tower(vision_tower_cfg, **kwargs):
9
+
10
+ vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
11
+ is_absolute_path_exists = os.path.exists(vision_tower)
12
+ use_s2 = getattr(vision_tower_cfg, "s2", False)
13
+
14
+ #print(getattr(vision_tower_cfg, "vision_tower", None))
15
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
16
+ if getattr(vision_tower_cfg, "vision_tower", None) and "siglip" in getattr(vision_tower_cfg, "vision_tower", None).lower():
17
+ #print('*************\n')
18
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
19
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
20
+ if use_s2:
21
+ return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
22
+ else:
23
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
24
+ # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower():
25
+ # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
26
+ # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]:
27
+ # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs)
28
+
29
+ raise ValueError(f"Unknown vision tower: {vision_tower}")
multimodal_encoder/.ipynb_checkpoints/clip_encoder-checkpoint.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from longva.longva.utils import rank0_print
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+ try:
7
+ from s2wrapper import forward as multiscale_forward
8
+ except:
9
+ pass
10
+
11
+
12
+ class CLIPVisionTower(nn.Module):
13
+ def __init__(self, vision_tower, args, delay_load=False):
14
+ super().__init__()
15
+
16
+ self.is_loaded = False
17
+
18
+ self.vision_tower_name = vision_tower
19
+ self.select_layer = args.mm_vision_select_layer
20
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
21
+
22
+ if not delay_load:
23
+ rank0_print(f"Loading vision tower: {vision_tower}")
24
+ self.load_model()
25
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
26
+ # TODO: better detector is needed.
27
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
28
+ self.load_model()
29
+ elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
30
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
31
+ self.load_model()
32
+ else:
33
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
34
+
35
+ def load_model(self, device_map=None):
36
+ if self.is_loaded:
37
+ rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
38
+ return
39
+
40
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
41
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
42
+ self.vision_tower.requires_grad_(False)
43
+
44
+ self.is_loaded = True
45
+
46
+ def feature_select(self, image_forward_outs):
47
+ select_feature_type = self.select_feature
48
+
49
+ if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
50
+ select_every_k_layer = len(image_forward_outs.hidden_states) // 4
51
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
52
+ select_feature_type = select_feature_type.replace("slicefour_", "")
53
+ elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
54
+ select_layers = [-2, -5, -8, -11, 6]
55
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
56
+ select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
57
+ else:
58
+ image_features = image_forward_outs.hidden_states[self.select_layer]
59
+
60
+ if select_feature_type == "patch":
61
+ image_features = image_features[:, 1:]
62
+ elif select_feature_type == "cls_patch":
63
+ image_features = image_features
64
+ else:
65
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
66
+ return image_features
67
+
68
+ def forward(self, images):
69
+ if type(images) is list:
70
+ image_features = []
71
+ for image in images:
72
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
73
+ #print('image_feature before select ',image_forward_out.hidden_states[-1].shape)
74
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
75
+ #print('image_feature after select ',image_feature.shape)
76
+ image_features.append(image_feature)
77
+ else:
78
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
79
+ #print('image_feature before select ',image_forward_outs.hidden_states[-1].shape)
80
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
81
+ #print('image_feature after select ',image_features.shape)
82
+
83
+ return image_features
84
+
85
+ @property
86
+ def dummy_feature(self):
87
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
88
+
89
+ @property
90
+ def dtype(self):
91
+ return self.vision_tower.dtype
92
+
93
+ @property
94
+ def device(self):
95
+ return self.vision_tower.device
96
+
97
+ @property
98
+ def config(self):
99
+ if self.is_loaded:
100
+ return self.vision_tower.config
101
+ else:
102
+ return self.cfg_only
103
+
104
+ @property
105
+ def hidden_size(self):
106
+ _hidden_size = self.config.hidden_size
107
+ if "slicefour" in self.select_feature:
108
+ _hidden_size *= 4
109
+ if "slice_m25811_f6" in self.select_feature:
110
+ _hidden_size *= 5
111
+ return _hidden_size
112
+
113
+ @property
114
+ def num_patches_per_side(self):
115
+ return self.config.image_size // self.config.patch_size
116
+
117
+ @property
118
+ def num_patches(self):
119
+ _num_patches = (self.config.image_size // self.config.patch_size) ** 2
120
+ if "cls_patch" in self.select_feature:
121
+ _num_patches += 1
122
+ return _num_patches
123
+
124
+ @property
125
+ def image_size(self):
126
+ return self.config.image_size
127
+
128
+
129
+ class CLIPVisionTowerS2(CLIPVisionTower):
130
+ def __init__(self, vision_tower, args, delay_load=False):
131
+
132
+ self.s2_scales = getattr(args, "s2_scales", "336,672,1008")
133
+ self.s2_scales = list(map(int, self.s2_scales.split(",")))
134
+ self.s2_scales.sort()
135
+ self.s2_split_size = self.s2_scales[0]
136
+ self.s2_image_size = self.s2_scales[-1]
137
+
138
+ super().__init__(vision_tower, args, delay_load)
139
+
140
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
141
+ if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False):
142
+ self.image_processor.size["shortest_edge"] = self.s2_image_size
143
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
144
+
145
+ def load_model(self, device_map=None):
146
+ if self.is_loaded:
147
+ rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
148
+ return
149
+
150
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
151
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
152
+ self.vision_tower.requires_grad_(False)
153
+
154
+ self.image_processor.size["shortest_edge"] = self.s2_image_size
155
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
156
+
157
+ self.is_loaded = True
158
+
159
+ @torch.no_grad()
160
+ def forward_feature(self, images):
161
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
162
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
163
+ return image_features
164
+
165
+ @torch.no_grad()
166
+ def forward(self, images):
167
+ if type(images) is list:
168
+ image_features = []
169
+ for image in images:
170
+ image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
171
+ image_features.append(image_feature)
172
+ else:
173
+ image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
174
+
175
+ return image_features
176
+
177
+ @property
178
+ def hidden_size(self):
179
+ return self.config.hidden_size * len(self.s2_scales)
multimodal_encoder/.ipynb_checkpoints/siglip_encoder-checkpoint.py ADDED
@@ -0,0 +1,151 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from typing import Optional, Tuple, Union, Dict
5
+ from PIL import Image
6
+ from functools import partial, reduce
7
+ from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
8
+
9
+ from .base_encoder import BaseVisionTower
10
+ import torch.distributed as dist
11
+ # --data_path /share/shuyan/video_traindata/anno/\{cinepine_order\}.json \
12
+ # --image_folder /share/shuyan/video_traindata/Bunny-v1_0-data/finetune/images \
13
+ # --video_folder /share/shuyan/video_traindata \
14
+ def rank0_print(*args):
15
+ if dist.is_initialized():
16
+ if dist.get_rank() == 0:
17
+ print(f"Rank {dist.get_rank()}: ", *args)
18
+ else:
19
+ print(*args)
20
+
21
+
22
+ from transformers.image_processing_utils import BatchFeature, get_size_dict
23
+ from transformers.image_transforms import (
24
+ convert_to_rgb,
25
+ normalize,
26
+ rescale,
27
+ resize,
28
+ to_channel_dimension_format,
29
+ )
30
+ from transformers.image_utils import (
31
+ ChannelDimension,
32
+ PILImageResampling,
33
+ to_numpy_array,
34
+ )
35
+ class SigLipImageProcessor:
36
+ def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
37
+ crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
38
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
39
+
40
+ self.image_mean = image_mean
41
+ self.image_std = image_std
42
+ self.size = size
43
+ self.resample = resample
44
+ self.rescale_factor = rescale_factor
45
+ self.data_format = data_format
46
+ self.crop_size = crop_size
47
+
48
+ def preprocess(self, images, return_tensors):
49
+ if isinstance(images, Image.Image):
50
+ images = [images]
51
+ else:
52
+ # to adapt video data
53
+ images = [to_numpy_array(image) for image in images]
54
+ assert isinstance(images, list)
55
+
56
+ transforms = [
57
+ convert_to_rgb,
58
+ to_numpy_array,
59
+ partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
60
+ partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
61
+ partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
62
+ partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
63
+ ]
64
+
65
+ images = reduce(lambda x, f: [*map(f, x)], transforms, images)
66
+
67
+ data = {"pixel_values": images}
68
+
69
+ return BatchFeature(data=data, tensor_type=return_tensors)
70
+
71
+ class SigLipVisionTower(BaseVisionTower):
72
+ def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
73
+ super(SigLipVisionTower, self).__init__(vision_tower_name, vision_tower_cfg, delay_load)
74
+
75
+ model_path = "google/siglip-so400m-patch14-384"
76
+ base_model_name, res, interp = model_path, 384, 576
77
+ self.vision_tower_name = base_model_name
78
+ self._image_size = res if res is not None else 512
79
+ self.unfreeze_mm_vision_tower = getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False)
80
+
81
+ if not delay_load:
82
+ rank0_print(f"Loading vision tower: {vision_tower_name}")
83
+ self.load_model()
84
+ elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
85
+ # TODO: better detector is needed.
86
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
87
+ self.load_model()
88
+ elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
89
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
90
+ self.load_model()
91
+ else:
92
+ self.cfg_only = self.config
93
+
94
+ def load_model(self, device_map=None):
95
+ self.vision_model = "siglip"
96
+ # clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
97
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
98
+
99
+ # self.vision_tower = clip_model.visual.trunk
100
+ self.vision_tower.output_tokens = True
101
+
102
+ self._hidden_size = self.vision_tower.config.hidden_size
103
+
104
+ self.image_processor = SigLipImageProcessor()
105
+
106
+ del self.vision_tower.vision_model.encoder.layers[-1:]
107
+ self.vision_tower.vision_model.head = nn.Identity()
108
+
109
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
110
+ self.is_loaded = True
111
+
112
+ def _forward(self, images):
113
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
114
+ image_features = self.vision_tower.forward(
115
+ images.to(device=self.device, dtype=self.dtype),
116
+ output_hidden_states=True,
117
+ ).hidden_states[-1]
118
+ return image_features
119
+ @property
120
+ def dummy_feature(self):
121
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
122
+
123
+ @property
124
+ def dtype(self):
125
+ for p in self.vision_tower.parameters():
126
+ return p.dtype
127
+
128
+ @property
129
+ def device(self):
130
+ for p in self.vision_tower.parameters():
131
+ return p.device
132
+
133
+ @property
134
+ def hidden_size(self):
135
+ return self.config.hidden_size
136
+
137
+ @property
138
+ def num_patches(self):
139
+ return (336 // 14) ** 2
140
+
141
+ @property
142
+ def num_patches_per_side(self):
143
+ #return self.config.image_size // self.config.patch_size
144
+ return 336//14
145
+ #return 27
146
+ # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
147
+
148
+ @property
149
+ def image_size(self):
150
+ return 384
151
+ #return self.config.image_size
multimodal_encoder/__pycache__/base_encoder.cpython-310.pyc ADDED
Binary file (2.62 kB).
 
multimodal_encoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (697 Bytes).
 
multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc ADDED
Binary file (6.53 kB).
 
multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc ADDED
Binary file (5.81 kB).
 
multimodal_encoder/base_encoder.py ADDED
@@ -0,0 +1,68 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class BaseVisionTower(nn.Module):
8
+ def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower_name
14
+ self.delay_load = delay_load
15
+
16
+ @abstractmethod
17
+ def load_model(self, device_map=None):
18
+ raise NotImplementedError("Subclasses must implement load_model")
19
+
20
+ @abstractmethod
21
+ def _forward(self, images):
22
+ raise NotImplementedError("Subclasses must implement forward")
23
+
24
+ def forward(self, images):
25
+ if type(images) is list:
26
+ image_features = [self._forward(image.unsqueeze(0)) for image in images]
27
+ else:
28
+ image_features = self._forward(images)
29
+
30
+ return image_features
31
+
32
+ @property
33
+ def dummy_feature(self):
34
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
35
+
36
+ @property
37
+ def dtype(self):
38
+ # Dynamically infer the dtype from the first parameter, if not explicitly specified
39
+ if hasattr(self.vision_tower, "dtype"):
40
+ return self.vision_tower.dtype
41
+ else:
42
+ params = list(self.vision_tower.parameters())
43
+ return (
44
+ params[0].dtype if len(params) > 0 else torch.float32
45
+ ) # Default to torch.float32 if no parameters
46
+
47
+ @property
48
+ def device(self):
49
+ # Dynamically infer the device from the first parameter, if not explicitly specified
50
+ if hasattr(self.vision_tower, "device"):
51
+ return self.vision_tower.device
52
+ else:
53
+ params = list(self.vision_tower.parameters())
54
+ return (
55
+ params[0].device if len(params) > 0 else torch.device("cpu")
56
+ ) # Default to CPU if no parameters
57
+ @property
58
+ def config(self):
59
+ if self.is_loaded:
60
+ return self.vision_tower.config
61
+ else:
62
+ return self.cfg_only
63
+ @property
64
+ def hidden_size(self):
65
+ try:
66
+ return self.config.hidden_size
67
+ except:
68
+ return self._hidden_size
multimodal_encoder/builder.py ADDED
@@ -0,0 +1,20 @@
1
+ import os
2
+ from .siglip_encoder import SigLipVisionTower
3
+ # from .eva_clip.eva_clip_encoder import EvaClipVisionTower
4
+ # from .dev_eva_clip.eva_vit import EvaViTWrapper
5
+
6
+
7
+ def build_vision_tower(vision_tower_cfg, **kwargs):
8
+
9
+ vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
10
+ is_absolute_path_exists = os.path.exists(vision_tower)
11
+ use_s2 = getattr(vision_tower_cfg, "s2", False)
12
+
13
+ #print(getattr(vision_tower_cfg, "vision_tower", None))
14
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
15
+ if getattr(vision_tower_cfg, "vision_tower", None) and "siglip" in getattr(vision_tower_cfg, "vision_tower", None).lower():
16
+ #print('*************\n')
17
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
18
+
19
+
20
+ raise ValueError(f"Unknown vision tower: {vision_tower}")
multimodal_encoder/siglip_encoder.py ADDED
@@ -0,0 +1,154 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from typing import Optional, Tuple, Union, Dict
5
+ from PIL import Image
6
+ from functools import partial, reduce
7
+ from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
8
+
9
+ from .base_encoder import BaseVisionTower
10
+ import torch.distributed as dist
11
+ # --data_path /share/shuyan/video_traindata/anno/\{cinepine_order\}.json \
12
+ # --image_folder /share/shuyan/video_traindata/Bunny-v1_0-data/finetune/images \
13
+ # --video_folder /share/shuyan/video_traindata \
14
+ def rank0_print(*args):
15
+ if dist.is_initialized():
16
+ if dist.get_rank() == 0:
17
+ print(f"Rank {dist.get_rank()}: ", *args)
18
+ else:
19
+ print(*args)
20
+
21
+
22
+ from transformers.image_processing_utils import BatchFeature, get_size_dict
23
+ from transformers.image_transforms import (
24
+ convert_to_rgb,
25
+ normalize,
26
+ rescale,
27
+ resize,
28
+ to_channel_dimension_format,
29
+ )
30
+ from transformers.image_utils import (
31
+ ChannelDimension,
32
+ PILImageResampling,
33
+ to_numpy_array,
34
+ )
35
+ class SigLipImageProcessor:
36
+ def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
37
+ crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
38
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
39
+
40
+ self.image_mean = image_mean
41
+ self.image_std = image_std
42
+ self.size = size
43
+ self.resample = resample
44
+ self.rescale_factor = rescale_factor
45
+ self.data_format = data_format
46
+ self.crop_size = crop_size
47
+
48
+ def preprocess(self, images, return_tensors):
49
+ if isinstance(images, Image.Image):
50
+ images = [images]
51
+ else:
52
+ # to adapt video data
53
+ images = [to_numpy_array(image) for image in images]
54
+ assert isinstance(images, list)
55
+
56
+ transforms = [
57
+ convert_to_rgb,
58
+ to_numpy_array,
59
+ partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
60
+ partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
61
+ partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
62
+ partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
63
+ ]
64
+
65
+ images = reduce(lambda x, f: [*map(f, x)], transforms, images)
66
+
67
+ data = {"pixel_values": images}
68
+
69
+ return BatchFeature(data=data, tensor_type=return_tensors)
70
+
71
+ class SigLipVisionTower(BaseVisionTower):
72
+ def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
73
+ super(SigLipVisionTower, self).__init__(vision_tower_name, vision_tower_cfg, delay_load)
74
+
75
+ # model_path = "google/siglip-so400m-patch14-384"
76
+ # base_model_name, res, interp = model_path, 384, 576
77
+ # self.vision_tower_name = base_model_name
78
+ self.vision_tower_name, res, interp = vision_tower_name, 384, 576
79
+ self._image_size = res if res is not None else 512
80
+ self.unfreeze_mm_vision_tower = getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False)
81
+
82
+ if not delay_load:
83
+ rank0_print(f"Loading vision tower: {vision_tower_name}")
84
+ self.load_model()
85
+ elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
86
+ # TODO: better detector is needed.
87
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
88
+ self.load_model()
89
+ elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
90
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
91
+ self.load_model()
92
+ else:
93
+ self.cfg_only = self.config
94
+
95
+ def load_model(self, device_map=None):
96
+ self.vision_model = "siglip"
97
+ # clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
98
+ print(self.vision_tower_name)
99
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
100
+
101
+ # self.vision_tower = clip_model.visual.trunk
102
+ self.vision_tower.output_tokens = True
103
+
104
+ self._hidden_size = self.vision_tower.config.hidden_size
105
+
106
+ self.image_processor = SigLipImageProcessor()
107
+
108
+ del self.vision_tower.vision_model.encoder.layers[-1:]
109
+ self.vision_tower.vision_model.head = nn.Identity()
110
+
111
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
112
+
113
+ self.is_loaded = True
114
+
115
+ def _forward(self, images):
116
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
117
+ image_features = self.vision_tower.forward(
118
+ images.to(device=self.device, dtype=self.dtype),
119
+ output_hidden_states=True,
120
+ ).hidden_states[-1]
121
+ return image_features
122
+ @property
123
+ def dummy_feature(self):
124
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
125
+
126
+ @property
127
+ def dtype(self):
128
+ for p in self.vision_tower.parameters():
129
+ return p.dtype
130
+
131
+ @property
132
+ def device(self):
133
+ for p in self.vision_tower.parameters():
134
+ return p.device
135
+
136
+ @property
137
+ def hidden_size(self):
138
+ return self.config.hidden_size
139
+
140
+ @property
141
+ def num_patches(self):
142
+ return (336 // 14) ** 2
143
+
144
+ @property
145
+ def num_patches_per_side(self):
146
+ #return self.config.image_size // self.config.patch_size
147
+ return 336//14
148
+ #return 27
149
+ # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
150
+
151
+ @property
152
+ def image_size(self):
153
+ return 384
154
+ #return self.config.image_size
multimodal_projector/__pycache__/builder.cpython-310.pyc ADDED
Binary file (2.4 kB).
 
multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc ADDED
Binary file (1.47 kB).
 
multimodal_projector/builder.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+ from .pooler_projector import PoolerProjector
6
+
7
+
8
+ class IdentityMap(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def forward(self, x, *args, **kwargs):
13
+ return x
14
+
15
+ @property
16
+ def config(self):
17
+ return {"mm_projector_type": "identity"}
18
+
19
+
20
+ class SimpleResBlock(nn.Module):
21
+ def __init__(self, channels):
22
+ super().__init__()
23
+ self.pre_norm = nn.LayerNorm(channels)
24
+
25
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
26
+
27
+ def forward(self, x):
28
+ x = self.pre_norm(x)
29
+ return x + self.proj(x)
30
+
31
+
32
+ def build_vision_projector(config, delay_load=False, **kwargs):
33
+ projector_type = getattr(config, "mm_projector_type", "linear")
34
+
35
+ if projector_type == "linear":
36
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
37
+
38
+ if projector_type == "pooler":
39
+ return PoolerProjector(config, kwargs["vision_cfg"])
40
+
41
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
42
+ if mlp_gelu_match:
43
+ mlp_depth = int(mlp_gelu_match.group(1))
44
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
45
+ for _ in range(1, mlp_depth):
46
+ modules.append(nn.GELU())
47
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
48
+ return nn.Sequential(*modules)
49
+
50
+ mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
51
+ if mlp_gelu_resnet_match:
52
+ mlp_depth = int(mlp_gelu_resnet_match.group(1))
53
+ res_depth = int(mlp_gelu_resnet_match.group(2))
54
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
55
+ for _ in range(1, mlp_depth):
56
+ modules.append(nn.GELU())
57
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
58
+ for _ in range(res_depth):
59
+ modules.append(SimpleResBlock(config.hidden_size))
60
+ return nn.Sequential(*modules)
61
+
62
+ if projector_type == "identity":
63
+ return IdentityMap()
64
+
65
+ raise ValueError(f"Unknown projector type: {projector_type}")
multimodal_projector/pooler_projector.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import math
5
+
6
+ from transformers.models.clip.modeling_clip import CLIPVisionModel
7
+
8
+
9
+ class PoolerProjector(nn.Module):
10
+ def __init__(self, config, vision_cfg):
11
+ super().__init__()
12
+ self._config = config
13
+ self.hw = vision_cfg.image_size // vision_cfg.patch_size
14
+
15
+ self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
16
+
17
+ self.proj = nn.Sequential(
18
+ nn.GELU(),
19
+ nn.Linear(config.hidden_size, config.hidden_size),
20
+ )
21
+
22
+ def forward(self, x, *args, **kwargs):
23
+ height = width = self.hw
24
+ assert height * width == x.shape[1]
25
+ x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
26
+ x = self.conv_pool(x)
27
+ x = x.flatten(2).transpose(1, 2)
28
+ x = self.proj(x)
29
+ return x
30
+
31
+ @property
32
+ def config(self):
33
+ return {"mm_projector_type": "pooler"}
multimodal_resampler/__pycache__/builder.cpython-310.pyc ADDED
Binary file (1.45 kB).
 
multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc ADDED
Binary file (2.47 kB).
 
multimodal_resampler/__pycache__/perceiver.cpython-310.pyc ADDED
Binary file (4.86 kB).
 
multimodal_resampler/__pycache__/qformer.cpython-310.pyc ADDED
Binary file (32.7 kB).
 
multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc ADDED
Binary file (1.9 kB).
 
multimodal_resampler/builder.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+
3
+ from .masked_drop import MaskedDrop
4
+ from .spatial_pool import SpatialPool
5
+ from .perceiver import PerceiverResampler
6
+ from .qformer import Qformer
7
+
8
+
9
+ class IdentityMap(torch.nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ def forward(self, x, *args, **kwargs):
14
+ return x
15
+
16
+ @property
17
+ def config(self):
18
+ return {"mm_resampler_type": None}
19
+
20
+
21
+ def build_vision_resampler(model_args, delay_load=False, **kwargs):
22
+ resampler_type = getattr(model_args, "mm_resampler_type", None)
23
+ if resampler_type == "masked_drop":
24
+ return MaskedDrop(model_args)
25
+ elif resampler_type == "spatial_pool":
26
+ return SpatialPool(model_args, **kwargs)
27
+ elif resampler_type == "perceiver":
28
+ return PerceiverResampler(model_args, **kwargs)
29
+ elif resampler_type == "qformer":
30
+ return Qformer(model_args, **kwargs)
31
+ elif resampler_type is None:
32
+ return IdentityMap()
33
+
34
+ raise ValueError(f"Unknown resampler type: {resampler_type}")
multimodal_resampler/masked_drop.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import random
5
+
6
+
7
+ class MaskedDrop(nn.Module):
8
+ def __init__(self, model_args):
9
+ super().__init__()
10
+
11
+ self.mode = model_args.mm_mask_drop_mode
12
+ self.skip_percentage = model_args.mm_mask_drop_skip_percentage
13
+ self.ratio = model_args.mm_mask_drop_ratio
14
+ self.ratio_upper = model_args.mm_mask_drop_ratio_upper
15
+ self.ratio_lower = model_args.mm_mask_drop_ratio_lower
16
+
17
+ def forward(self, image_features, *args, **kwargs):
18
+
19
+ if not self.training:
20
+ return image_features
21
+
22
+ if self.skip_percentage > random.random():
23
+ return image_features
24
+
25
+ masked_features = []
26
+
27
+ for image_feature in image_features:
28
+ num_tokens = image_feature.shape[0]
29
+ if self.mode == "fixed":
30
+ num_keep = int(num_tokens * self.ratio)
31
+ masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
32
+ elif self.mode == "range":
33
+ num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
34
+ masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
35
+ elif self.mode == "cls_only":
36
+ masked_features.append(image_feature[0:1])
37
+ else:
38
+ raise ValueError(f"Unexpected masked drop mode: {self.mode}")
39
+
40
+ if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
41
+ masked_features = torch.stack(masked_features, dim=0)
42
+
43
+ return masked_features
44
+
45
+ @property
46
+ def config(self):
47
+ return {
48
+ "mm_resampler_type": "masked_drop",
49
+ "mm_mask_drop_mode": self.mode,
50
+ "mm_mask_drop_skip_percentage": self.skip_percentage,
51
+ "mm_mask_drop_ratio": self.ratio,
52
+ "mm_mask_drop_ratio_upper": self.ratio_upper,
53
+ "mm_mask_drop_ratio_lower": self.ratio_lower,
54
+ }
55
+
56
+ def random_masking(self, x, len_keep):
57
+ """
58
+ Perform per-sample random masking by per-sample shuffling.
59
+ Per-sample shuffling is done by argsort random noise.
60
+ x: [N, L, D], sequence
61
+ """
62
+ N, L, D = x.shape # batch, length, dim
63
+
64
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
65
+
66
+ # sort noise for each sample
67
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
68
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
69
+
70
+ # keep the first subset
71
+ ids_keep = ids_shuffle[:, :len_keep]
72
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
73
+
74
+ # generate the binary mask: 0 is keep, 1 is remove
75
+ mask = torch.ones([N, L], device=x.device)
76
+ mask[:, :len_keep] = 0
77
+ # unshuffle to get the binary mask
78
+ mask = torch.gather(mask, dim=1, index=ids_restore)
79
+
80
+ return x_masked, mask, ids_restore
multimodal_resampler/perceiver.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from https://github.com/lucidrains/flamingo-pytorch
3
+ """
4
+
5
+ import torch
6
+ from einops import rearrange, repeat
7
+
8
+ try:
9
+ from einops_exts import rearrange_many
10
+ except:
11
+ pass
12
+
13
+ from torch import einsum, nn
14
+
15
+
16
+ def exists(val):
17
+ return val is not None
18
+
19
+
20
+ def FeedForward(dim, mult=4):
21
+ inner_dim = int(dim * mult)
22
+ return nn.Sequential(
23
+ nn.LayerNorm(dim),
24
+ nn.Linear(dim, inner_dim, bias=False),
25
+ nn.GELU(),
26
+ nn.Linear(inner_dim, dim, bias=False),
27
+ )
28
+
29
+
30
+ class PerceiverAttention(nn.Module):
31
+ def __init__(self, *, dim, dim_head=64, heads=8):
32
+ super().__init__()
33
+ self.scale = dim_head**-0.5
34
+ self.heads = heads
35
+ inner_dim = dim_head * heads
36
+
37
+ self.norm_media = nn.LayerNorm(dim)
38
+ self.norm_latents = nn.LayerNorm(dim)
39
+
40
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
41
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
42
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
43
+
44
+ def forward(self, x, latents):
45
+ """
46
+ Args:
47
+ x (torch.Tensor): image features
48
+ shape (b, T, n1, D)
49
+ latent (torch.Tensor): latent features
50
+ shape (b, T, n2, D)
51
+ """
52
+ x = self.norm_media(x)
53
+ latents = self.norm_latents(latents)
54
+
55
+ h = self.heads
56
+
57
+ q = self.to_q(latents)
58
+ kv_input = torch.cat((x, latents), dim=-2)
59
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
61
+ q = q * self.scale
62
+
63
+ # attention
64
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
65
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
66
+ attn = sim.softmax(dim=-1)
67
+
68
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
69
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
70
+ return self.to_out(out)
71
+
72
+
73
+ class PerceiverResamplerModule(nn.Module):
74
+ def __init__(
75
+ self,
76
+ *,
77
+ dim,
78
+ depth=6,
79
+ dim_head=64,
80
+ heads=8,
81
+ num_latents=64,
82
+ max_num_media=None,
83
+ max_num_frames=None,
84
+ ff_mult=4,
85
+ ):
86
+ super().__init__()
87
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
88
+ self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
89
+ self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
90
+
91
+ self.layers = nn.ModuleList([])
92
+ for _ in range(depth):
93
+ self.layers.append(
94
+ nn.ModuleList(
95
+ [
96
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
97
+ FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(),
98
+ ]
99
+ )
100
+ )
101
+
102
+ self.norm = nn.LayerNorm(dim)
103
+
104
+ def forward(self, x):
105
+ """
106
+ Args:
107
+ x (torch.Tensor): image features
108
+ shape (b, T, F, v, D)
109
+ Returns:
110
+ shape (b, T, n, D) where n is self.num_latents
111
+ """
112
+ b, T, F, v = x.shape[:4]
113
+
114
+ # frame and media time embeddings
115
+ if exists(self.frame_embs):
116
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
117
+ x = x + frame_embs
118
+ x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
119
+ if exists(self.media_time_embs):
120
+ x = x + self.media_time_embs[:T]
121
+
122
+ # blocks
123
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
124
+ for attn, ff in self.layers:
125
+ latents = attn(x, latents) + latents
126
+ latents = ff(latents) + latents
127
+ return self.norm(latents)
128
+
129
+
130
+ class PerceiverResampler(nn.Module):
131
+ def __init__(self, model_args, vision_tower):
132
+ super().__init__()
133
+
134
+ self.depth = model_args.mm_perceiver_depth
135
+ self.num_latents = model_args.mm_perceiver_latents
136
+ self.ff_mult = model_args.mm_perceiver_ff_mult
137
+ self.pretrained = model_args.mm_perceiver_pretrained
138
+
139
+ self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)
140
+
141
+ if self.pretrained is not None:
142
+ self.load_state_dict(torch.load(self.pretrained))
143
+
144
+ def forward(self, image_features, *args, **kwargs):
145
+ return self.perceiver(image_features[:, None, None]).squeeze(1)
146
+
147
+ @property
148
+ def config(self):
149
+ return {
150
+ "mm_resampler_type": "perceiver",
151
+ "mm_perceiver_depth": self.depth,
152
+ "mm_perceiver_latents": self.num_latents,
153
+ "mm_perceiver_ff_mult": self.ff_mult,
154
+ "mm_perceiver_pretrained": self.pretrained,
155
+ }
multimodal_resampler/qformer.py ADDED
@@ -0,0 +1,1160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ def disabled_train(self, mode=True):
52
+ """Overwrite model.train with this function to make sure train/eval mode
53
+ does not change anymore."""
54
+ return self
55
+
56
+
57
+ class BertEmbeddings(nn.Module):
58
+ """Construct the embeddings from word and position embeddings."""
59
+
60
+ def __init__(self, config):
61
+ super().__init__()
62
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
63
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
64
+
65
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
66
+ # any TensorFlow checkpoint file
67
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
68
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
69
+
70
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
71
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
72
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
73
+
74
+ self.config = config
75
+
76
+ def forward(
77
+ self,
78
+ input_ids=None,
79
+ position_ids=None,
80
+ query_embeds=None,
81
+ past_key_values_length=0,
82
+ ):
83
+ if input_ids is not None:
84
+ seq_length = input_ids.size()[1]
85
+ else:
86
+ seq_length = 0
87
+
88
+ if position_ids is None:
89
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
90
+
91
+ if input_ids is not None:
92
+ embeddings = self.word_embeddings(input_ids)
93
+ if self.position_embedding_type == "absolute":
94
+ position_embeddings = self.position_embeddings(position_ids)
95
+ embeddings = embeddings + position_embeddings
96
+
97
+ if query_embeds is not None:
98
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
99
+ else:
100
+ embeddings = query_embeds
101
+
102
+ embeddings = self.LayerNorm(embeddings)
103
+ embeddings = self.dropout(embeddings)
104
+ return embeddings
105
+
106
+
107
+ class BertSelfAttention(nn.Module):
108
+ def __init__(self, config, is_cross_attention):
109
+ super().__init__()
110
+ self.config = config
111
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
112
+ raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads))
113
+
114
+ self.num_attention_heads = config.num_attention_heads
115
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
116
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
117
+
118
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
119
+ if is_cross_attention:
120
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
121
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
122
+ else:
123
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
124
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
125
+
126
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
127
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
128
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
129
+ self.max_position_embeddings = config.max_position_embeddings
130
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
131
+ self.save_attention = False
132
+
133
+ def save_attn_gradients(self, attn_gradients):
134
+ self.attn_gradients = attn_gradients
135
+
136
+ def get_attn_gradients(self):
137
+ return self.attn_gradients
138
+
139
+ def save_attention_map(self, attention_map):
140
+ self.attention_map = attention_map
141
+
142
+ def get_attention_map(self):
143
+ return self.attention_map
144
+
145
+ def transpose_for_scores(self, x):
146
+ new_x_shape = x.size()[:-1] + (
147
+ self.num_attention_heads,
148
+ self.attention_head_size,
149
+ )
150
+ x = x.view(*new_x_shape)
151
+ return x.permute(0, 2, 1, 3)
152
+
153
+ def forward(
154
+ self,
155
+ hidden_states,
156
+ attention_mask=None,
157
+ head_mask=None,
158
+ encoder_hidden_states=None,
159
+ encoder_attention_mask=None,
160
+ past_key_value=None,
161
+ output_attentions=False,
162
+ ):
163
+
164
+ # If this is instantiated as a cross-attention module, the keys
165
+ # and values come from an encoder; the attention mask needs to be
166
+ # such that the encoder's padding tokens are not attended to.
167
+ is_cross_attention = encoder_hidden_states is not None
168
+
169
+ if is_cross_attention:
170
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
172
+ attention_mask = encoder_attention_mask
173
+ elif past_key_value is not None:
174
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
175
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
176
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
177
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
178
+ else:
179
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
180
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
181
+
182
+ mixed_query_layer = self.query(hidden_states)
183
+
184
+ query_layer = self.transpose_for_scores(mixed_query_layer)
185
+
186
+ past_key_value = (key_layer, value_layer)
187
+
188
+ # Take the dot product between "query" and "key" to get the raw attention scores.
189
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
190
+
191
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
192
+ seq_length = hidden_states.size()[1]
193
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
194
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
195
+ distance = position_ids_l - position_ids_r
196
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
197
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
198
+
199
+ if self.position_embedding_type == "relative_key":
200
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
201
+ attention_scores = attention_scores + relative_position_scores
202
+ elif self.position_embedding_type == "relative_key_query":
203
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
204
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
205
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
206
+
207
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
208
+ if attention_mask is not None:
209
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
210
+ attention_scores = attention_scores + attention_mask
211
+
212
+ # Normalize the attention scores to probabilities.
213
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
214
+
215
+ if is_cross_attention and self.save_attention:
216
+ self.save_attention_map(attention_probs)
217
+ attention_probs.register_hook(self.save_attn_gradients)
218
+
219
+ # This is actually dropping out entire tokens to attend to, which might
220
+ # seem a bit unusual, but is taken from the original Transformer paper.
221
+ attention_probs_dropped = self.dropout(attention_probs)
222
+
223
+ # Mask heads if we want to
224
+ if head_mask is not None:
225
+ attention_probs_dropped = attention_probs_dropped * head_mask
226
+
227
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
228
+
229
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
230
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
231
+ context_layer = context_layer.view(*new_context_layer_shape)
232
+
233
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
234
+
235
+ outputs = outputs + (past_key_value,)
236
+ return outputs
237
+
238
+
239
+ class BertSelfOutput(nn.Module):
240
+ def __init__(self, config):
241
+ super().__init__()
242
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
243
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
244
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
245
+
246
+ def forward(self, hidden_states, input_tensor):
247
+ hidden_states = self.dense(hidden_states)
248
+ hidden_states = self.dropout(hidden_states)
249
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
250
+ return hidden_states
251
+
252
+
253
+ class BertAttention(nn.Module):
254
+ def __init__(self, config, is_cross_attention=False):
255
+ super().__init__()
256
+ self.self = BertSelfAttention(config, is_cross_attention)
257
+ self.output = BertSelfOutput(config)
258
+ self.pruned_heads = set()
259
+
260
+ def prune_heads(self, heads):
261
+ if len(heads) == 0:
262
+ return
263
+ heads, index = find_pruneable_heads_and_indices(
264
+ heads,
265
+ self.self.num_attention_heads,
266
+ self.self.attention_head_size,
267
+ self.pruned_heads,
268
+ )
269
+
270
+ # Prune linear layers
271
+ self.self.query = prune_linear_layer(self.self.query, index)
272
+ self.self.key = prune_linear_layer(self.self.key, index)
273
+ self.self.value = prune_linear_layer(self.self.value, index)
274
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
275
+
276
+ # Update hyper params and store pruned heads
277
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
278
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
279
+ self.pruned_heads = self.pruned_heads.union(heads)
280
+
281
+ def forward(
282
+ self,
283
+ hidden_states,
284
+ attention_mask=None,
285
+ head_mask=None,
286
+ encoder_hidden_states=None,
287
+ encoder_attention_mask=None,
288
+ past_key_value=None,
289
+ output_attentions=False,
290
+ ):
291
+ self_outputs = self.self(
292
+ hidden_states,
293
+ attention_mask,
294
+ head_mask,
295
+ encoder_hidden_states,
296
+ encoder_attention_mask,
297
+ past_key_value,
298
+ output_attentions,
299
+ )
300
+ attention_output = self.output(self_outputs[0], hidden_states)
301
+
302
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
303
+ return outputs
304
+
305
+
306
+ class BertIntermediate(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
310
+ if isinstance(config.hidden_act, str):
311
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
312
+ else:
313
+ self.intermediate_act_fn = config.hidden_act
314
+
315
+ def forward(self, hidden_states):
316
+ hidden_states = self.dense(hidden_states)
317
+ hidden_states = self.intermediate_act_fn(hidden_states)
318
+ return hidden_states
319
+
320
+
321
+ class BertOutput(nn.Module):
322
+ def __init__(self, config):
323
+ super().__init__()
324
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
325
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
326
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
327
+
328
+ def forward(self, hidden_states, input_tensor):
329
+ hidden_states = self.dense(hidden_states)
330
+ hidden_states = self.dropout(hidden_states)
331
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
332
+ return hidden_states
333
+
334
+
335
+ class BertLayer(nn.Module):
336
+ def __init__(self, config, layer_num):
337
+ super().__init__()
338
+ self.config = config
339
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
340
+ self.seq_len_dim = 1
341
+ self.attention = BertAttention(config)
342
+ self.layer_num = layer_num
343
+ if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0:
344
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
345
+ self.has_cross_attention = True
346
+ else:
347
+ self.has_cross_attention = False
348
+ self.intermediate = BertIntermediate(config)
349
+ self.output = BertOutput(config)
350
+
351
+ self.intermediate_query = BertIntermediate(config)
352
+ self.output_query = BertOutput(config)
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states,
357
+ attention_mask=None,
358
+ head_mask=None,
359
+ encoder_hidden_states=None,
360
+ encoder_attention_mask=None,
361
+ past_key_value=None,
362
+ output_attentions=False,
363
+ query_length=0,
364
+ ):
365
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
366
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
367
+ self_attention_outputs = self.attention(
368
+ hidden_states,
369
+ attention_mask,
370
+ head_mask,
371
+ output_attentions=output_attentions,
372
+ past_key_value=self_attn_past_key_value,
373
+ )
374
+ attention_output = self_attention_outputs[0]
375
+ outputs = self_attention_outputs[1:-1]
376
+
377
+ present_key_value = self_attention_outputs[-1]
378
+
379
+ if query_length > 0:
380
+ query_attention_output = attention_output[:, :query_length, :]
381
+
382
+ if self.has_cross_attention:
383
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
384
+ cross_attention_outputs = self.crossattention(
385
+ query_attention_output,
386
+ attention_mask,
387
+ head_mask,
388
+ encoder_hidden_states,
389
+ encoder_attention_mask,
390
+ output_attentions=output_attentions,
391
+ )
392
+ query_attention_output = cross_attention_outputs[0]
393
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
394
+
395
+ layer_output = apply_chunking_to_forward(
396
+ self.feed_forward_chunk_query,
397
+ self.chunk_size_feed_forward,
398
+ self.seq_len_dim,
399
+ query_attention_output,
400
+ )
401
+ if attention_output.shape[1] > query_length:
402
+ layer_output_text = apply_chunking_to_forward(
403
+ self.feed_forward_chunk,
404
+ self.chunk_size_feed_forward,
405
+ self.seq_len_dim,
406
+ attention_output[:, query_length:, :],
407
+ )
408
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
409
+ else:
410
+ layer_output = apply_chunking_to_forward(
411
+ self.feed_forward_chunk,
412
+ self.chunk_size_feed_forward,
413
+ self.seq_len_dim,
414
+ attention_output,
415
+ )
416
+ outputs = (layer_output,) + outputs
417
+
418
+ outputs = outputs + (present_key_value,)
419
+
420
+ return outputs
421
+
422
+ def feed_forward_chunk(self, attention_output):
423
+ intermediate_output = self.intermediate(attention_output)
424
+ layer_output = self.output(intermediate_output, attention_output)
425
+ return layer_output
426
+
427
+ def feed_forward_chunk_query(self, attention_output):
428
+ intermediate_output = self.intermediate_query(attention_output)
429
+ layer_output = self.output_query(intermediate_output, attention_output)
430
+ return layer_output
431
+
432
+
433
+ class BertEncoder(nn.Module):
434
+ def __init__(self, config):
435
+ super().__init__()
436
+ self.config = config
437
+ self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
438
+
439
+ def forward(
440
+ self,
441
+ hidden_states,
442
+ attention_mask=None,
443
+ head_mask=None,
444
+ encoder_hidden_states=None,
445
+ encoder_attention_mask=None,
446
+ past_key_values=None,
447
+ use_cache=None,
448
+ output_attentions=False,
449
+ output_hidden_states=False,
450
+ return_dict=True,
451
+ query_length=0,
452
+ ):
453
+ all_hidden_states = () if output_hidden_states else None
454
+ all_self_attentions = () if output_attentions else None
455
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
456
+
457
+ next_decoder_cache = () if use_cache else None
458
+
459
+ for i in range(self.config.num_hidden_layers):
460
+ layer_module = self.layer[i]
461
+ if output_hidden_states:
462
+ all_hidden_states = all_hidden_states + (hidden_states,)
463
+
464
+ layer_head_mask = head_mask[i] if head_mask is not None else None
465
+ past_key_value = past_key_values[i] if past_key_values is not None else None
466
+
467
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
468
+
469
+ if use_cache:
470
+ logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
471
+ use_cache = False
472
+
473
+ def create_custom_forward(module):
474
+ def custom_forward(*inputs):
475
+ return module(*inputs, past_key_value, output_attentions, query_length)
476
+
477
+ return custom_forward
478
+
479
+ layer_outputs = torch.utils.checkpoint.checkpoint(
480
+ create_custom_forward(layer_module),
481
+ hidden_states,
482
+ attention_mask,
483
+ layer_head_mask,
484
+ encoder_hidden_states,
485
+ encoder_attention_mask,
486
+ )
487
+ else:
488
+ layer_outputs = layer_module(
489
+ hidden_states,
490
+ attention_mask,
491
+ layer_head_mask,
492
+ encoder_hidden_states,
493
+ encoder_attention_mask,
494
+ past_key_value,
495
+ output_attentions,
496
+ query_length,
497
+ )
498
+
499
+ hidden_states = layer_outputs[0]
500
+ if use_cache:
501
+ next_decoder_cache += (layer_outputs[-1],)
502
+ if output_attentions:
503
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
504
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
505
+
506
+ if output_hidden_states:
507
+ all_hidden_states = all_hidden_states + (hidden_states,)
508
+
509
+ if not return_dict:
510
+ return tuple(
511
+ v
512
+ for v in [
513
+ hidden_states,
514
+ next_decoder_cache,
515
+ all_hidden_states,
516
+ all_self_attentions,
517
+ all_cross_attentions,
518
+ ]
519
+ if v is not None
520
+ )
521
+ return BaseModelOutputWithPastAndCrossAttentions(
522
+ last_hidden_state=hidden_states,
523
+ past_key_values=next_decoder_cache,
524
+ hidden_states=all_hidden_states,
525
+ attentions=all_self_attentions,
526
+ cross_attentions=all_cross_attentions,
527
+ )
528
+
529
+
530
+ class BertPooler(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
534
+ self.activation = nn.Tanh()
535
+
536
+ def forward(self, hidden_states):
537
+ # We "pool" the model by simply taking the hidden state corresponding
538
+ # to the first token.
539
+ first_token_tensor = hidden_states[:, 0]
540
+ pooled_output = self.dense(first_token_tensor)
541
+ pooled_output = self.activation(pooled_output)
542
+ return pooled_output
543
+
544
+
545
+ class BertPredictionHeadTransform(nn.Module):
546
+ def __init__(self, config):
547
+ super().__init__()
548
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
549
+ if isinstance(config.hidden_act, str):
550
+ self.transform_act_fn = ACT2FN[config.hidden_act]
551
+ else:
552
+ self.transform_act_fn = config.hidden_act
553
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
554
+
555
+ def forward(self, hidden_states):
556
+ hidden_states = self.dense(hidden_states)
557
+ hidden_states = self.transform_act_fn(hidden_states)
558
+ hidden_states = self.LayerNorm(hidden_states)
559
+ return hidden_states
560
+
561
+
562
+ class BertLMPredictionHead(nn.Module):
563
+ def __init__(self, config):
564
+ super().__init__()
565
+ self.transform = BertPredictionHeadTransform(config)
566
+
567
+ # The output weights are the same as the input embeddings, but there is
568
+ # an output-only bias for each token.
569
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
570
+
571
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
572
+
573
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
574
+ self.decoder.bias = self.bias
575
+
576
+ def forward(self, hidden_states):
577
+ hidden_states = self.transform(hidden_states)
578
+ hidden_states = self.decoder(hidden_states)
579
+ return hidden_states
580
+
581
+
582
+ class BertOnlyMLMHead(nn.Module):
583
+ def __init__(self, config):
584
+ super().__init__()
585
+ self.predictions = BertLMPredictionHead(config)
586
+
587
+ def forward(self, sequence_output):
588
+ prediction_scores = self.predictions(sequence_output)
589
+ return prediction_scores
590
+
591
+
592
+ class BertPreTrainedModel(PreTrainedModel):
593
+ """
594
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
595
+ models.
596
+ """
597
+
598
+ config_class = BertConfig
599
+ base_model_prefix = "bert"
600
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
601
+
602
+ def _init_weights(self, module):
603
+ """Initialize the weights"""
604
+ if isinstance(module, (nn.Linear, nn.Embedding)):
605
+ # Slightly different from the TF version which uses truncated_normal for initialization
606
+ # cf https://github.com/pytorch/pytorch/pull/5617
607
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
608
+ elif isinstance(module, nn.LayerNorm):
609
+ module.bias.data.zero_()
610
+ module.weight.data.fill_(1.0)
611
+ if isinstance(module, nn.Linear) and module.bias is not None:
612
+ module.bias.data.zero_()
613
+
614
+
615
+ class BertModel(BertPreTrainedModel):
616
+ """
617
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
618
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
619
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
620
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
621
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
622
+ input to the forward pass.
623
+ """
624
+
625
+ def __init__(self, config, add_pooling_layer=False):
626
+ super().__init__(config)
627
+ self.config = config
628
+
629
+ self.embeddings = BertEmbeddings(config)
630
+
631
+ self.encoder = BertEncoder(config)
632
+
633
+ self.pooler = BertPooler(config) if add_pooling_layer else None
634
+
635
+ self.init_weights()
636
+
637
+ def get_input_embeddings(self):
638
+ return self.embeddings.word_embeddings
639
+
640
+ def set_input_embeddings(self, value):
641
+ self.embeddings.word_embeddings = value
642
+
643
+ def _prune_heads(self, heads_to_prune):
644
+ """
645
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
646
+ class PreTrainedModel
647
+ """
648
+ for layer, heads in heads_to_prune.items():
649
+ self.encoder.layer[layer].attention.prune_heads(heads)
650
+
651
+ def get_extended_attention_mask(
652
+ self,
653
+ attention_mask: Tensor,
654
+ input_shape: Tuple[int],
655
+ device: device,
656
+ is_decoder: bool,
657
+ has_query: bool = False,
658
+ ) -> Tensor:
659
+ """
660
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
661
+
662
+ Arguments:
663
+ attention_mask (:obj:`torch.Tensor`):
664
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
665
+ input_shape (:obj:`Tuple[int]`):
666
+ The shape of the input to the model.
667
+ device: (:obj:`torch.device`):
668
+ The device of the input to the model.
669
+
670
+ Returns:
671
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
672
+ """
673
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
674
+ # ourselves in which case we just need to make it broadcastable to all heads.
675
+ if attention_mask.dim() == 3:
676
+ extended_attention_mask = attention_mask[:, None, :, :]
677
+ elif attention_mask.dim() == 2:
678
+ # Provided a padding mask of dimensions [batch_size, seq_length]
679
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
680
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
681
+ if is_decoder:
682
+ batch_size, seq_length = input_shape
683
+
684
+ seq_ids = torch.arange(seq_length, device=device)
685
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
686
+
687
+ # add a prefix ones mask to the causal mask
688
+ # causal and attention masks must have same type with pytorch version < 1.3
689
+ causal_mask = causal_mask.to(attention_mask.dtype)
690
+
691
+ if causal_mask.shape[1] < attention_mask.shape[1]:
692
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
693
+ if has_query: # UniLM style attention mask
694
+ causal_mask = torch.cat(
695
+ [
696
+ torch.zeros(
697
+ (batch_size, prefix_seq_len, seq_length),
698
+ device=device,
699
+ dtype=causal_mask.dtype,
700
+ ),
701
+ causal_mask,
702
+ ],
703
+ axis=1,
704
+ )
705
+ causal_mask = torch.cat(
706
+ [
707
+ torch.ones(
708
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
709
+ device=device,
710
+ dtype=causal_mask.dtype,
711
+ ),
712
+ causal_mask,
713
+ ],
714
+ axis=-1,
715
+ )
716
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
717
+ else:
718
+ extended_attention_mask = attention_mask[:, None, None, :]
719
+ else:
720
+ raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))
721
+
722
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
723
+ # masked positions, this operation will create a tensor which is 0.0 for
724
+ # positions we want to attend and -10000.0 for masked positions.
725
+ # Since we are adding it to the raw scores before the softmax, this is
726
+ # effectively the same as removing these entirely.
727
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
728
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
729
+ return extended_attention_mask
730
+
731
+ def forward(
732
+ self,
733
+ input_ids=None,
734
+ attention_mask=None,
735
+ position_ids=None,
736
+ head_mask=None,
737
+ query_embeds=None,
738
+ encoder_hidden_states=None,
739
+ encoder_attention_mask=None,
740
+ past_key_values=None,
741
+ use_cache=None,
742
+ output_attentions=None,
743
+ output_hidden_states=None,
744
+ return_dict=None,
745
+ is_decoder=False,
746
+ ):
747
+ r"""
748
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
749
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
750
+ the model is configured as a decoder.
751
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
752
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
753
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
754
+ - 1 for tokens that are **not masked**,
755
+ - 0 for tokens that are **masked**.
756
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
757
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
758
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
759
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
760
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
761
+ use_cache (:obj:`bool`, `optional`):
762
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
763
+ decoding (see :obj:`past_key_values`).
764
+ """
765
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
766
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
767
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
768
+
769
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
770
+
771
+ if input_ids is None:
772
+ assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"
773
+
774
+ # past_key_values_length
775
+ past_key_values_length = past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
776
+
777
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
778
+
779
+ embedding_output = self.embeddings(
780
+ input_ids=input_ids,
781
+ position_ids=position_ids,
782
+ query_embeds=query_embeds,
783
+ past_key_values_length=past_key_values_length,
784
+ )
785
+
786
+ input_shape = embedding_output.size()[:-1]
787
+ batch_size, seq_length = input_shape
788
+ device = embedding_output.device
789
+
790
+ if attention_mask is None:
791
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
792
+
793
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
794
+ # ourselves in which case we just need to make it broadcastable to all heads.
795
+ if is_decoder:
796
+ extended_attention_mask = self.get_extended_attention_mask(
797
+ attention_mask,
798
+ input_ids.shape,
799
+ device,
800
+ is_decoder,
801
+ has_query=(query_embeds is not None),
802
+ )
803
+ else:
804
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)
805
+
806
+ # If a 2D or 3D attention mask is provided for the cross-attention
807
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
808
+ if encoder_hidden_states is not None:
809
+ if type(encoder_hidden_states) == list:
810
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
811
+ else:
812
+ (
813
+ encoder_batch_size,
814
+ encoder_sequence_length,
815
+ _,
816
+ ) = encoder_hidden_states.size()
817
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
818
+
819
+ if type(encoder_attention_mask) == list:
820
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
821
+ elif encoder_attention_mask is None:
822
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
823
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
824
+ else:
825
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
826
+ else:
827
+ encoder_extended_attention_mask = None
828
+
829
+ # Prepare head mask if needed
830
+ # 1.0 in head_mask indicate we keep the head
831
+ # attention_probs has shape bsz x n_heads x N x N
832
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
833
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
834
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
835
+
836
+ encoder_outputs = self.encoder(
837
+ embedding_output,
838
+ attention_mask=extended_attention_mask,
839
+ head_mask=head_mask,
840
+ encoder_hidden_states=encoder_hidden_states,
841
+ encoder_attention_mask=encoder_extended_attention_mask,
842
+ past_key_values=past_key_values,
843
+ use_cache=use_cache,
844
+ output_attentions=output_attentions,
845
+ output_hidden_states=output_hidden_states,
846
+ return_dict=return_dict,
847
+ query_length=query_length,
848
+ )
849
+ sequence_output = encoder_outputs[0]
850
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
851
+
852
+ if not return_dict:
853
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
854
+
855
+ return BaseModelOutputWithPoolingAndCrossAttentions(
856
+ last_hidden_state=sequence_output,
857
+ pooler_output=pooled_output,
858
+ past_key_values=encoder_outputs.past_key_values,
859
+ hidden_states=encoder_outputs.hidden_states,
860
+ attentions=encoder_outputs.attentions,
861
+ cross_attentions=encoder_outputs.cross_attentions,
862
+ )
863
+
864
+
865
+ class BertLMHeadModel(BertPreTrainedModel):
866
+
867
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
868
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
869
+
870
+ def __init__(self, config):
871
+ super().__init__(config)
872
+
873
+ self.bert = BertModel(config, add_pooling_layer=False)
874
+ self.cls = BertOnlyMLMHead(config)
875
+
876
+ self.init_weights()
877
+
878
+ def get_output_embeddings(self):
879
+ return self.cls.predictions.decoder
880
+
881
+ def set_output_embeddings(self, new_embeddings):
882
+ self.cls.predictions.decoder = new_embeddings
883
+
884
+ def forward(
885
+ self,
886
+ input_ids=None,
887
+ attention_mask=None,
888
+ position_ids=None,
889
+ head_mask=None,
890
+ query_embeds=None,
891
+ encoder_hidden_states=None,
892
+ encoder_attention_mask=None,
893
+ labels=None,
894
+ past_key_values=None,
895
+ use_cache=True,
896
+ output_attentions=None,
897
+ output_hidden_states=None,
898
+ return_dict=None,
899
+ return_logits=False,
900
+ is_decoder=True,
901
+ reduction="mean",
902
+ ):
903
+ r"""
904
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
905
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
906
+ the model is configured as a decoder.
907
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
908
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
909
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
910
+ - 1 for tokens that are **not masked**,
911
+ - 0 for tokens that are **masked**.
912
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
913
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
914
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
915
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
916
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
917
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
918
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
919
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
920
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
921
+ use_cache (:obj:`bool`, `optional`):
922
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
923
+ decoding (see :obj:`past_key_values`).
924
+ Returns:
925
+ Example::
926
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
927
+ >>> import torch
928
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
929
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
930
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
931
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
932
+ >>> outputs = model(**inputs)
933
+ >>> prediction_logits = outputs.logits
934
+ """
935
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
936
+ if labels is not None:
937
+ use_cache = False
938
+ if past_key_values is not None:
939
+ query_embeds = None
940
+
941
+ outputs = self.bert(
942
+ input_ids,
943
+ attention_mask=attention_mask,
944
+ position_ids=position_ids,
945
+ head_mask=head_mask,
946
+ query_embeds=query_embeds,
947
+ encoder_hidden_states=encoder_hidden_states,
948
+ encoder_attention_mask=encoder_attention_mask,
949
+ past_key_values=past_key_values,
950
+ use_cache=use_cache,
951
+ output_attentions=output_attentions,
952
+ output_hidden_states=output_hidden_states,
953
+ return_dict=return_dict,
954
+ is_decoder=is_decoder,
955
+ )
956
+
957
+ sequence_output = outputs[0]
958
+ if query_embeds is not None:
959
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
960
+
961
+ prediction_scores = self.cls(sequence_output)
962
+
963
+ if return_logits:
964
+ return prediction_scores[:, :-1, :].contiguous()
965
+
966
+ lm_loss = None
967
+ if labels is not None:
968
+ # we are doing next-token prediction; shift prediction scores and input ids by one
969
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
970
+ labels = labels[:, 1:].contiguous()
971
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
972
+ lm_loss = loss_fct(
973
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
974
+ labels.view(-1),
975
+ )
976
+ if reduction == "none":
977
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
978
+
979
+ if not return_dict:
980
+ output = (prediction_scores,) + outputs[2:]
981
+ return ((lm_loss,) + output) if lm_loss is not None else output
982
+
983
+ return CausalLMOutputWithCrossAttentions(
984
+ loss=lm_loss,
985
+ logits=prediction_scores,
986
+ past_key_values=outputs.past_key_values,
987
+ hidden_states=outputs.hidden_states,
988
+ attentions=outputs.attentions,
989
+ cross_attentions=outputs.cross_attentions,
990
+ )
991
+
992
+ def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
993
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
994
+ if attention_mask is None:
995
+ attention_mask = input_ids.new_ones(input_ids.shape)
996
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
997
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
998
+
999
+ # cut decoder_input_ids if past is used
1000
+ if past is not None:
1001
+ input_ids = input_ids[:, -1:]
1002
+
1003
+ return {
1004
+ "input_ids": input_ids,
1005
+ "query_embeds": query_embeds,
1006
+ "attention_mask": attention_mask,
1007
+ "past_key_values": past,
1008
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1009
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1010
+ "is_decoder": True,
1011
+ }
1012
+
1013
+ def _reorder_cache(self, past, beam_idx):
1014
+ reordered_past = ()
1015
+ for layer_past in past:
1016
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1017
+ return reordered_past
1018
+
1019
+
1020
+ class BertForMaskedLM(BertPreTrainedModel):
1021
+
1022
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1023
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1024
+
1025
+ def __init__(self, config):
1026
+ super().__init__(config)
1027
+
1028
+ self.bert = BertModel(config, add_pooling_layer=False)
1029
+ self.cls = BertOnlyMLMHead(config)
1030
+
1031
+ self.init_weights()
1032
+
1033
+ def get_output_embeddings(self):
1034
+ return self.cls.predictions.decoder
1035
+
1036
+ def set_output_embeddings(self, new_embeddings):
1037
+ self.cls.predictions.decoder = new_embeddings
1038
+
1039
+ def forward(
1040
+ self,
1041
+ input_ids=None,
1042
+ attention_mask=None,
1043
+ position_ids=None,
1044
+ head_mask=None,
1045
+ query_embeds=None,
1046
+ encoder_hidden_states=None,
1047
+ encoder_attention_mask=None,
1048
+ labels=None,
1049
+ output_attentions=None,
1050
+ output_hidden_states=None,
1051
+ return_dict=None,
1052
+ return_logits=False,
1053
+ is_decoder=False,
1054
+ ):
1055
+ r"""
1056
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1057
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1058
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1059
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1060
+ """
1061
+
1062
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1063
+
1064
+ outputs = self.bert(
1065
+ input_ids,
1066
+ attention_mask=attention_mask,
1067
+ position_ids=position_ids,
1068
+ head_mask=head_mask,
1069
+ query_embeds=query_embeds,
1070
+ encoder_hidden_states=encoder_hidden_states,
1071
+ encoder_attention_mask=encoder_attention_mask,
1072
+ output_attentions=output_attentions,
1073
+ output_hidden_states=output_hidden_states,
1074
+ return_dict=return_dict,
1075
+ is_decoder=is_decoder,
1076
+ )
1077
+
1078
+ if query_embeds is not None:
1079
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1080
+ prediction_scores = self.cls(sequence_output)
1081
+
1082
+ if return_logits:
1083
+ return prediction_scores
1084
+
1085
+ masked_lm_loss = None
1086
+ if labels is not None:
1087
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1088
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1089
+
1090
+ if not return_dict:
1091
+ output = (prediction_scores,) + outputs[2:]
1092
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1093
+
1094
+ return MaskedLMOutput(
1095
+ loss=masked_lm_loss,
1096
+ logits=prediction_scores,
1097
+ hidden_states=outputs.hidden_states,
1098
+ attentions=outputs.attentions,
1099
+ )
1100
+
1101
+
1102
+ class Qformer(nn.Module):
1103
+ def __init__(self, model_args, vision_tower):
1104
+ super().__init__()
1105
+
1106
+ self.depth = model_args.mm_qformer_depth
1107
+ self.num_latents = model_args.mm_qformer_latents
1108
+ self.pretrained = model_args.mm_qformer_pretrained
1109
+
1110
+ self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)
1111
+
1112
+ if self.pretrained is not None:
1113
+ pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"]
1114
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")}
1115
+ self.load_state_dict(pretrained_dict)
1116
+
1117
+ def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
1118
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
1119
+ encoder_config.encoder_width = vision_width
1120
+ # insert cross-attention layer every other block
1121
+ encoder_config.add_cross_attention = True
1122
+ encoder_config.cross_attention_freq = cross_attention_freq
1123
+ encoder_config.query_length = num_query_token
1124
+ Qformer = BertLMHeadModel(config=encoder_config)
1125
+ query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
1126
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
1127
+ Qformer.cls = None
1128
+ Qformer.bert.embeddings.word_embeddings = None
1129
+ Qformer.bert.embeddings.position_embeddings = None
1130
+ for layer in Qformer.bert.encoder.layer:
1131
+ layer.output = None
1132
+ layer.intermediate = None
1133
+ return Qformer, query_tokens, nn.LayerNorm(vision_width)
1134
+
1135
+ def forward(self, image_features, *args, **kwargs):
1136
+ x = self.ln_vision(image_features)
1137
+ image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device)
1138
+
1139
+ query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
1140
+ query_output = self.Qformer.bert(
1141
+ query_embeds=query_tokens,
1142
+ encoder_hidden_states=x,
1143
+ encoder_attention_mask=image_atts,
1144
+ return_dict=True,
1145
+ )
1146
+
1147
+ return query_output.last_hidden_state
1148
+
1149
+ @property
1150
+ def hidden_size(self):
1151
+ return 768
1152
+
1153
+ @property
1154
+ def config(self):
1155
+ return {
1156
+ "mm_resampler_type": "qformer",
1157
+ "mm_qformer_depth": self.depth,
1158
+ "mm_qformer_latents": self.num_latents,
1159
+ "mm_qformer_pretrained": self.pretrained,
1160
+ }
multimodal_resampler/spatial_pool.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+
6
+ class SpatialPool(nn.Module):
7
+ def __init__(self, model_args, vision_tower):
8
+ super().__init__()
9
+
10
+ self.mode = model_args.mm_spatial_pool_mode
11
+ self.stride = model_args.mm_spatial_pool_stride
12
+ self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)
13
+
14
+ if self.mode == "average":
15
+ self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
16
+ elif self.mode == "max":
17
+ self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
18
+ elif self.mode == "conv":
19
+ self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
20
+ else:
21
+ raise ValueError(f"Unknown pooling mode: {self.pool}.")
22
+
23
+ def forward(self, image_features, images, *args, **kwargs):
24
+ ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
25
+ ori_H = int(ori_W * images.shape[2] // images.shape[3])
26
+
27
+ B, _, F = image_features.shape
28
+
29
+ image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2)
30
+ image_features_spatial_pool = self.pool(image_features_spatial)
31
+
32
+ return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
33
+
34
+ @property
35
+ def config(self):
36
+ return {
37
+ "mm_resampler_type": "spatial_pool",
38
+ "mm_spatial_pool_stride": self.stride,
39
+ "mm_spatial_pool_mode": self.mode,
40
+ "mm_spatial_pool_out_channels": self.out_channels,
41
+ }
42
+
43
+ @property
44
+ def hidden_size(self):
45
+ return self.out_channels