WishArdently committed on
Commit edf2ce7 · verified · 1 parent: 809e672

Upload InternVideo2Stage2VideoEncoder

Files changed (8)
  1. config.json +2 -2
  2. config.py +1 -1
  3. flash_attention_class.py +74 -0
  4. internvideo2.py +780 -0
  5. internvideo2_stage2.py +101 -0
  6. model.py +3 -3
  7. model.safetensors +1 -1
  8. pos_embed.py +299 -0
config.json CHANGED
@@ -102,7 +102,7 @@
  "num_frames": 8,
  "only_mask": true,
  "patch_size": 14,
- "pretrained": "/home/bingxing2/home/scx7l3k/linanxi/workspace/low_level/Encoders/InternVideo2-stage2_1b-224p-f4.pt",
+ "pretrained": "/home/linanxi/InternVideo/checkpoints/InternVideo2-stage2_1b-224p-f4/InternVideo2-stage2_1b-224p-f4.pt",
  "sep_image_video_pos_embed": true,
  "tubelet_size": 1,
  "use_checkpoint": false,
@@ -156,7 +156,7 @@
  "tokenizer": null,
  "torch_dtype": "float16",
  "train_file": "available_corpus[\"pretrain_example_data_1B\"]",
- "transformers_version": "4.42.4",
+ "transformers_version": "4.47.0",
  "use_bf16": true,
  "use_flash_sdp": false,
  "use_half_precision": false,
config.py CHANGED
@@ -132,7 +132,7 @@ class InternVideo2Config(PretrainedConfig):
  "clip_norm_type": "l2",
  "clip_return_layer": 6,
  "clip_student_return_interval": 1,
- "pretrained": "/home/bingxing2/home/scx7l3k/linanxi/workspace/low_level/Encoders/InternVideo2-stage2_1b-224p-f4.pt",
+ "pretrained": "/home/linanxi/InternVideo/checkpoints/InternVideo2-stage2_1b-224p-f4/InternVideo2-stage2_1b-224p-f4.pt",
  "use_checkpoint": False,
  "checkpoint_num": 40,
  "use_flash_attn": True,
flash_attention_class.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from einops import rearrange
5
+
6
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
7
+ from flash_attn.bert_padding import unpad_input, pad_input
8
+
9
+
10
+ class FlashAttention(nn.Module):
11
+ """Implement the scaled dot product attention with softmax.
12
+ Arguments
13
+ ---------
14
+ softmax_scale: The temperature to use for the softmax attention.
15
+ (default: 1/sqrt(d_keys) where d_keys is computed at
16
+ runtime)
17
+ attention_dropout: The dropout rate to apply to the attention
18
+ (default: 0.0)
19
+ """
20
+
21
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
22
+ super().__init__()
23
+ self.softmax_scale = softmax_scale
24
+ self.dropout_p = attention_dropout
25
+
26
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
27
+ max_s=None, need_weights=False):
28
+ """Implements the multihead softmax attention.
29
+ Arguments
30
+ ---------
31
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
32
+ if unpadded: (nnz, 3, h, d)
33
+ key_padding_mask: a bool tensor of shape (B, S)
34
+ """
35
+
36
+ # qkv = qkv.to(torch.float16)
37
+
38
+ assert not need_weights
39
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
40
+ assert qkv.is_cuda
41
+
42
+ if cu_seqlens is None:
43
+ batch_size = qkv.shape[0]
44
+ seqlen = qkv.shape[1]
45
+ if key_padding_mask is None:
46
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
47
+ max_s = seqlen
48
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
49
+ device=qkv.device)
50
+ output = flash_attn_varlen_qkvpacked_func(
51
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
52
+ softmax_scale=self.softmax_scale, causal=causal
53
+ )
54
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
55
+ else:
56
+ nheads = qkv.shape[-2]
57
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
58
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
59
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
60
+ output_unpad = flash_attn_varlen_qkvpacked_func(
61
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
62
+ softmax_scale=self.softmax_scale, causal=causal
63
+ )
64
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
65
+ indices, batch_size, seqlen),
66
+ 'b s (h d) -> b s h d', h=nheads)
67
+ else:
68
+ assert max_s is not None
69
+ output = flash_attn_varlen_qkvpacked_func(
70
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
71
+ softmax_scale=self.softmax_scale, causal=causal
72
+ )
73
+
74
+ return output, None
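For reference, the wrapper above packs Q, K and V into a single `(B, S, 3, H, D)` tensor and hands it to `flash_attn_varlen_qkvpacked_func`. Below is a plain-PyTorch sketch of the same computation for the unpadded, non-causal case; it is an illustration only (no FlashAttention kernel, runs on CPU) and is not part of the commit:

```python
import torch

def packed_qkv_attention(qkv, softmax_scale=None):
    """qkv: (B, S, 3, H, D) -> (B, S, H, D), per-head softmax attention."""
    q, k, v = qkv.unbind(dim=2)                         # each (B, S, H, D)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))    # (B, H, S, D)
    scale = softmax_scale or q.shape[-1] ** -0.5
    attn = torch.softmax((q * scale) @ k.transpose(-2, -1), dim=-1)  # (B, H, S, S)
    return (attn @ v).transpose(1, 2)                   # back to (B, S, H, D)

qkv = torch.randn(2, 16, 3, 4, 32)                      # toy sizes
print(packed_qkv_attention(qkv).shape)                  # torch.Size([2, 16, 4, 32])
```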
internvideo2.py ADDED
@@ -0,0 +1,780 @@
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5
+ from torch import nn
6
+
7
+ import torch.utils.checkpoint as checkpoint
8
+ from functools import partial
9
+ from einops import rearrange
10
+
11
+ from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed, interpolate_pos_embed_internvideo2
12
+ from .flash_attention_class import FlashAttention
13
+
14
+ from transformers.utils import logging as error_logging
15
+
16
+ # Set up logging
17
+ error_logging.set_verbosity_error()
18
+
19
+ try:
20
+ from flash_attn.modules.mlp import Mlp as FusedMLP
21
+ except:
22
+ pass
23
+
24
+ try:
25
+ from flash_attn.ops.rms_norm import DropoutAddRMSNorm
26
+ except:
27
+ pass
28
+
29
+
30
+ class CrossAttention(nn.Module):
31
+ def __init__(
32
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
33
+ proj_drop=0., attn_head_dim=None, out_dim=None):
34
+ super().__init__()
35
+ if out_dim is None:
36
+ out_dim = dim
37
+ self.num_heads = num_heads
38
+ head_dim = dim // num_heads
39
+ if attn_head_dim is not None:
40
+ head_dim = attn_head_dim
41
+ all_head_dim = head_dim * self.num_heads
42
+ self.scale = qk_scale or head_dim ** -0.5
43
+ assert all_head_dim == dim
44
+
45
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
46
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
47
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
48
+
49
+ if qkv_bias:
50
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
51
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
52
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
53
+ else:
54
+ self.q_bias = None
55
+ self.k_bias = None
56
+ self.v_bias = None
57
+
58
+ self.attn_drop = nn.Dropout(attn_drop)
59
+ self.proj = nn.Linear(all_head_dim, out_dim)
60
+ self.proj_drop = nn.Dropout(proj_drop)
61
+
62
+ def forward(self, x, k=None, v=None):
63
+ B, N, C = x.shape
64
+ N_k = k.shape[1]
65
+ N_v = v.shape[1]
66
+
67
+ q_bias, k_bias, v_bias = None, None, None
68
+ if self.q_bias is not None:
69
+ q_bias = self.q_bias
70
+ k_bias = self.k_bias
71
+ v_bias = self.v_bias
72
+
73
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
74
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
75
+
76
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
77
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
78
+
79
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
80
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
81
+
82
+ q = q * self.scale
83
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
84
+
85
+ attn = attn.softmax(dim=-1)
86
+ attn = self.attn_drop(attn)
87
+
88
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
89
+ x = self.proj(x)
90
+ x = self.proj_drop(x)
91
+
92
+ return x
93
+
94
+
95
+ class AttentiveBlock(nn.Module):
96
+
97
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
98
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
99
+ super().__init__()
100
+
101
+ self.norm1_q = norm_layer(dim)
102
+ self.norm1_k = norm_layer(dim)
103
+ self.norm1_v = norm_layer(dim)
104
+ self.cross_attn = CrossAttention(
105
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
106
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
107
+
108
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
109
+
110
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
111
+ x_q = self.norm1_q(x_q + pos_q)
112
+ x_k = self.norm1_k(x_kv + pos_k)
113
+ x_v = self.norm1_v(x_kv)
114
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
115
+
116
+ return x
117
+
118
+
119
+ class AttentionPoolingBlock(AttentiveBlock):
120
+
121
+ def forward(self, x):
122
+ # x_q = x.mean(1, keepdim=True)
123
+ x_q = x
124
+ x_kv, pos_q, pos_k = x, 0, 0
125
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
126
+ x = x.squeeze(1)
127
+ return x
128
+
129
+
130
+ class RMSNorm(nn.Module):
131
+ def __init__(self, hidden_size, eps=1e-6):
132
+ super().__init__()
133
+ self.weight = nn.Parameter(torch.ones(hidden_size))
134
+ self.variance_epsilon = eps
135
+
136
+ def forward(self, hidden_states):
137
+ input_dtype = hidden_states.dtype
138
+ hidden_states = hidden_states.to(torch.float32)
139
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
140
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
141
+ return self.weight * hidden_states.to(input_dtype)
142
+
143
+
144
+ class LayerScale(nn.Module):
145
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
146
+ super().__init__()
147
+ self.inplace = inplace
148
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
149
+ self.force_fp32 = force_fp32
150
+
151
+ @torch.cuda.amp.autocast(enabled=False)
152
+ def forward(self, x):
153
+ if self.force_fp32:
154
+ output_type = x.dtype
155
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
156
+ return out.to(dtype=output_type)
157
+ else:
158
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
159
+ return out
160
+
161
+
162
+ class Attention(nn.Module):
163
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
164
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
165
+ super().__init__()
166
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
167
+ self.num_heads = num_heads
168
+ head_dim = dim // num_heads
169
+ self.scale = head_dim ** -0.5
170
+
171
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
172
+ self.attn_drop = nn.Dropout(attn_drop)
173
+ self.proj = nn.Linear(dim, dim)
174
+ self.proj_drop = nn.Dropout(proj_drop)
175
+
176
+ self.use_flash_attn = use_flash_attn
177
+ if use_flash_attn:
178
+ self.causal = causal
179
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
180
+
181
+ self.qk_normalization = qk_normalization
182
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
183
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
184
+ self.use_fused_rmsnorm = use_fused_rmsnorm
185
+
186
+ def _naive_attn(self, x):
187
+ B, N, C = x.shape
188
+ # print(x.shape, torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
189
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
190
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
191
+
192
+ if self.qk_normalization:
193
+ B_, H_, N_, D_ = q.shape
194
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
195
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
196
+
197
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
198
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
199
+ attn = attn.softmax(dim=-1)
200
+ attn = self.attn_drop(attn)
201
+ # print(torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
202
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
203
+ x = self.proj(x)
204
+ x = self.proj_drop(x)
205
+ return x
206
+
207
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
208
+
209
+ qkv = self.qkv(x)
210
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
211
+
212
+ if self.qk_normalization:
213
+ q, k, v = qkv.unbind(2)
214
+ if self.use_fused_rmsnorm:
215
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
216
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
217
+ else:
218
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
219
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
220
+ qkv = torch.stack([q, k, v], dim=2)
221
+
222
+ context, _ = self.inner_attn(
223
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
224
+ )
225
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
226
+ outs = self.proj_drop(outs)
227
+ return outs
228
+
229
+ def forward(self, x):
230
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
231
+ return x
232
+
233
+
234
+ class Mlp(nn.Module):
235
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
236
+ """
237
+
238
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
239
+ bias=True, drop=0.):
240
+ super().__init__()
241
+ out_features = out_features or in_features
242
+ hidden_features = hidden_features or in_features
243
+ bias = to_2tuple(bias)
244
+ drop_probs = to_2tuple(drop)
245
+
246
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
247
+ self.act = act_layer()
248
+ self.drop1 = nn.Dropout(drop_probs[0])
249
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
250
+ self.drop2 = nn.Dropout(drop_probs[1])
251
+
252
+ def forward(self, x):
253
+ x = self.fc1(x)
254
+ x = self.act(x)
255
+ x = self.drop1(x)
256
+ x = self.fc2(x)
257
+ x = self.drop2(x)
258
+ return x
259
+
260
+
261
+ class Block(nn.Module):
262
+
263
+ def __init__(
264
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
265
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
266
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
267
+ use_fused_rmsnorm=False):
268
+ super().__init__()
269
+
270
+ self.norm1 = norm_layer(dim)
271
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
272
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
273
+ qk_normalization=qk_normalization,
274
+ use_fused_rmsnorm=use_fused_rmsnorm)
275
+ self.ls1 = LayerScale(dim, init_values=init_values,
276
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
277
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
278
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
279
+
280
+ self.norm2 = norm_layer(dim)
281
+ mlp_hidden_dim = int(dim * mlp_ratio)
282
+ if use_fused_mlp:
283
+ # self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
284
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
285
+ else:
286
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
287
+ self.ls2 = LayerScale(dim, init_values=init_values,
288
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
289
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
290
+
291
+ self.with_cp = with_cp
292
+ self.use_fused_rmsnorm = use_fused_rmsnorm
293
+
294
+ def forward(self, x, residual=None):
295
+
296
+ def _inner_forward(x, residual=None):
297
+ if self.use_fused_rmsnorm:
298
+ x, residual = self.norm1(x, residual)
299
+ x = self.drop_path1(self.ls1(self.attn(x)))
300
+ x, residual = self.norm2(x, residual)
301
+ x = self.drop_path2(self.ls2(self.mlp(x)))
302
+ return x, residual
303
+ else:
304
+ assert residual is None
305
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
306
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
307
+ return x
308
+
309
+ if self.with_cp:
310
+ # print(f"\033[31m use_checkpoint [0m")
311
+ return checkpoint.checkpoint(_inner_forward, x, residual)
312
+ else:
313
+ return _inner_forward(x, residual=residual)
314
+
315
+
316
+ class PatchEmbed(nn.Module):
317
+ """ 3D Image to Patch Embedding
318
+ """
319
+
320
+ def __init__(
321
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
322
+ num_frames=8, tubelet_size=1, norm_layer=None
323
+ ):
324
+ super().__init__()
325
+ img_size = to_2tuple(img_size)
326
+ patch_size = to_2tuple(patch_size)
327
+ self.img_size = img_size
328
+ self.patch_size = patch_size
329
+ self.grid_size = (
330
+ num_frames // tubelet_size,
331
+ img_size[0] // patch_size[0],
332
+ img_size[1] // patch_size[1]
333
+ ) # (T, H, W)
334
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
335
+ self.num_img_patches = self.grid_size[1] * self.grid_size[2]
336
+
337
+ self.proj = nn.Conv3d(
338
+ in_channels=in_chans, out_channels=embed_dim,
339
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
340
+ stride=(tubelet_size, patch_size[0], patch_size[1])
341
+ )
342
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
343
+
344
+ def forward(self, x):
345
+ x = self.proj(x)
346
+ x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C
347
+ x = self.norm(x)
348
+ return x
349
+
350
+
351
+ class Linear_Decoder(nn.Module):
352
+ def __init__(self, in_channels=1408, out_channels=3200,
353
+ norm_layer=nn.LayerNorm, clip_norm_type='l2'):
354
+ super().__init__()
355
+ self.clip_norm_type = clip_norm_type
356
+ # logger.info(f'Normalization Type: {clip_norm_type}')
357
+
358
+ self.head = nn.Linear(in_channels, out_channels)
359
+ self.norm = norm_layer(out_channels)
360
+
361
+ self.apply(self._init_weights)
362
+
363
+ def _init_weights(self, m):
364
+ if isinstance(m, nn.Linear):
365
+ nn.init.xavier_uniform_(m.weight)
366
+ if isinstance(m, nn.Linear) and m.bias is not None:
367
+ nn.init.constant_(m.bias, 0)
368
+ elif isinstance(m, nn.LayerNorm):
369
+ nn.init.constant_(m.bias, 0)
370
+ nn.init.constant_(m.weight, 1.0)
371
+
372
+ def forward(self, x):
373
+ x = self.norm(self.head(x))
374
+
375
+ if self.clip_norm_type == 'l2':
376
+ x = x / x.norm(dim=-1, keepdim=True)
377
+ elif self.clip_norm_type == 'none':
378
+ pass
379
+ else:
380
+ raise NotImplementedError
381
+
382
+ return x
383
+
384
+
385
+ class PretrainInternVideo2(nn.Module):
386
+ def __init__(
387
+ self,
388
+ in_chans: int = 3,
389
+ patch_size: int = 14,
390
+ img_size: int = 224,
391
+ qkv_bias: bool = False,
392
+ drop_path_rate: float = 0.25,
393
+ embed_dim: int = 1408,
394
+ num_heads: int = 16,
395
+ mlp_ratio: float = 48/11,
396
+ init_values: float = 1e-5,
397
+ qk_normalization: bool = True,
398
+ depth: int = 40,
399
+ use_flash_attn: bool = True,
400
+ use_fused_rmsnorm: bool = True,
401
+ use_fused_mlp: bool = True,
402
+ fused_mlp_heuristic: int = 1,
403
+ attn_pool_num_heads: int = 16,
404
+ clip_embed_dim: int = 768,
405
+ layerscale_no_force_fp32: bool = False,
406
+ num_frames: int = 8,
407
+ tubelet_size: int = 1,
408
+ sep_pos_embed: bool = False,
409
+ sep_image_video_pos_embed: bool = False,
410
+ use_checkpoint: bool = False,
411
+ checkpoint_num: int = 0,
412
+ # for unmasked teacher
413
+ clip_teacher_embed_dim: int = 3200,
414
+ clip_teacher_final_dim: int = 768, # if 0, not distill final features
415
+ clip_norm_type: str = 'l2',
416
+ clip_return_layer: int = 1,
417
+ clip_student_return_interval: int = 1,
418
+ ):
419
+ super().__init__()
420
+
421
+ self.num_frames = num_frames
422
+ # print(f'num_frames: {num_frames}')
423
+ self.tubelet_size = tubelet_size
424
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent'
425
+
426
+ self.use_flash_attn = use_flash_attn
427
+ self.embed_dim = embed_dim
428
+
429
+ self.depth = depth
430
+ self.clip_norm_type = clip_norm_type
431
+ self.return_index = []
432
+ for i in range(clip_return_layer):
433
+ self.return_index.append(depth - int(i * clip_student_return_interval) - 1)
434
+ # logger.info(f'Normalization Type: {clip_norm_type}')
435
+ # logger.info(f'Student Return Index: {self.return_index}')
436
+
437
+ if use_fused_rmsnorm:
438
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
439
+ else:
440
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
441
+ self.norm_layer_for_blocks = norm_layer_for_blocks
442
+ self.patch_embed = PatchEmbed(
443
+ img_size, patch_size, in_chans, embed_dim,
444
+ num_frames=num_frames, tubelet_size=tubelet_size,
445
+ )
446
+ num_patches = self.patch_embed.num_patches
447
+ num_img_patches = self.patch_embed.num_img_patches
448
+
449
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
450
+
451
+ # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
452
+ self.sep_pos_embed = sep_pos_embed
453
+ self.sep_image_video_pos_embed = sep_image_video_pos_embed
454
+ if sep_pos_embed:
455
+ raise NotImplementedError
456
+ else:
457
+ if sep_image_video_pos_embed:
458
+ # logger.info("Use joint position embedding, for image and video we use different pos_embed.")
459
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
460
+ self.img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
461
+ # for CLIP decoder
462
+ self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
463
+ self.clip_img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
464
+ else:
465
+ # logger.info("Use joint position embedding, for image and video we use same pos_embed.")
466
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
467
+ self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
468
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
469
+ # choose which layer to use checkpoint
470
+ with_cp_list = [False] * depth
471
+ if use_checkpoint:
472
+ for idx in range(depth):
473
+ if idx < checkpoint_num:
474
+ with_cp_list[idx] = True
475
+ # logger.info(f"Droppath rate: {dpr}")
476
+ # logger.info(f"Checkpoint list: {with_cp_list}")
477
+
478
+ self.blocks = nn.ModuleList([
479
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
480
+ norm_layer=norm_layer_for_blocks,
481
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
482
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
483
+ fused_mlp_heuristic=fused_mlp_heuristic,
484
+ with_cp=with_cp_list[i],
485
+ qk_normalization=qk_normalization,
486
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
487
+ use_fused_rmsnorm=use_fused_rmsnorm)
488
+ for i in range(depth)])
489
+ self.clip_projector = AttentionPoolingBlock(
490
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
491
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
492
+
493
+ # CLIP decoder
494
+ self.clip_decoder = nn.ModuleList([
495
+ Linear_Decoder(
496
+ in_channels=embed_dim,
497
+ out_channels=clip_teacher_embed_dim,
498
+ norm_layer=partial(nn.LayerNorm, eps=1e-5),
499
+ clip_norm_type=clip_norm_type
500
+ ) for _ in range(clip_return_layer)
501
+ ])
502
+ self.final_clip_decoder = nn.Identity()
503
+ if clip_teacher_final_dim > 0:
504
+ self.final_clip_decoder = Linear_Decoder(
505
+ in_channels=clip_embed_dim,
506
+ out_channels=clip_teacher_final_dim,
507
+ norm_layer=partial(nn.LayerNorm, eps=1e-5),
508
+ clip_norm_type=clip_norm_type
509
+ )
510
+
511
+ self.init_pos_embed()
512
+ trunc_normal_(self.cls_token, std=.02)
513
+ self.apply(self._init_weights)
514
+ self.fix_init_weight()
515
+
516
+ def init_pos_embed(self):
517
+ # logger.info("Init pos_embed from sincos pos_embed")
518
+ if self.sep_pos_embed:
519
+ raise NotImplementedError
520
+ else:
521
+ # trunc_normal_(self.pos_embed, std=.02)
522
+ # trunc_normal_(self.clip_pos_embed, std=.02)
523
+ pos_embed = get_3d_sincos_pos_embed(
524
+ self.pos_embed.shape[-1],
525
+ self.patch_embed.grid_size[1], # height & weight
526
+ self.patch_embed.grid_size[0], # t_size
527
+ cls_token=True
528
+ )
529
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
530
+ self.clip_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
531
+
532
+ if self.sep_image_video_pos_embed:
533
+ img_pos_embed = get_3d_sincos_pos_embed(
534
+ self.pos_embed.shape[-1],
535
+ self.patch_embed.grid_size[1], # height & weight
536
+ 1,
537
+ cls_token=True
538
+ )
539
+ self.img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))
540
+ self.clip_img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))
541
+
542
+ def _init_weights(self, m):
543
+ if isinstance(m, nn.Linear):
544
+ trunc_normal_(m.weight, std=.02)
545
+ if isinstance(m, nn.Linear) and m.bias is not None:
546
+ nn.init.constant_(m.bias, 0)
547
+ elif isinstance(m, nn.LayerNorm):
548
+ nn.init.constant_(m.bias, 0)
549
+ nn.init.constant_(m.weight, 1.0)
550
+
551
+ def fix_init_weight(self):
552
+ def rescale(param, layer_id):
553
+ param.div_(math.sqrt(2.0 * layer_id))
554
+
555
+ for layer_id, layer in enumerate(self.blocks):
556
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
557
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
558
+
559
+ @property
560
+ def dtype(self):
561
+ return self.patch_embed.proj.weight.dtype
562
+
563
+ def get_num_layers(self):
564
+ return len(self.blocks)
565
+
566
+ @torch.jit.ignore
567
+ def no_weight_decay(self):
568
+ return {
569
+ 'pos_embed',
570
+ 'pos_embed_spatial',
571
+ 'pos_embed_temporal',
572
+ 'pos_embed_cls',
573
+ 'img_pos_embed',
574
+ 'cls_token',
575
+ 'clip_pos_embed',
576
+ 'clip_pos_embed_spatial',
577
+ 'clip_pos_embed_temporal',
578
+ 'clip_pos_embed_cls',
579
+ 'clip_img_pos_embed'
580
+ }
581
+
582
+ # @torch.cuda.amp.autocast(enabled=False)
583
+ def forward(self, x, mask=None, use_image=False, x_vis_return_idx=-1, x_vis_only=False):
584
+ # print(0, x.shape)
585
+ x = self.patch_embed(x.type(self.dtype))
586
+ # print(f"x.shape: {x.shape} x.dtype: {x.dtype}, model.dtype: {self.dtype}")
587
+ B, T, L, C = x.shape # T: temporal; L: spatial
588
+ x = x.view([B, T * L, C]) # (B, T * L, C)
589
+
590
+ # append cls token
591
+ cls_tokens = self.cls_token.expand(B, -1, -1)
592
+ x = torch.cat((cls_tokens, x), dim=1) # (B, T * L + 1, C)
593
+ # print(1, x.shape)
594
+
595
+ # add pos_embed
596
+ if self.sep_pos_embed:
597
+ raise NotImplementedError
598
+ else:
599
+ if use_image:
600
+ # print('use image') # No.
601
+ if self.sep_image_video_pos_embed:
602
+ pos_embed = self.img_pos_embed
603
+ else:
604
+ # (1, num_img_patches + 1, embed_dim)
605
+ # print('origin pos_embed.shape:', self.pos_embed.shape)
606
+ cls_pos_embed = self.pos_embed[:, 0:1, :]
607
+ # print('cls_pos_embed.shape:', cls_pos_embed.shape)
608
+
609
+ img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
610
+ # print('img_pos_embed.shape:', img_pos_embed.shape)
611
+
612
+ pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
613
+ # print('final img_pos_embed.shape:', pos_embed.shape)
614
+ else:
615
+ pos_embed = self.pos_embed
616
+ pos_embed = pos_embed[:, :x.shape[1], :]
617
+ x = x + pos_embed
618
+
619
+ # mask tokens, ~mask means visible
620
+ if mask is not None:
621
+ x = x[~mask].reshape(B, -1, C)
622
+ else:
623
+ x = x.reshape(B, -1, C)
624
+ residual = None
625
+ x_clip = []
626
+ for idx, blk in enumerate(self.blocks):
627
+ if isinstance(x, tuple) and len(x) == 2:
628
+ x, residual = x
629
+ # print(f"\033[31m这是{idx}, {x.shape}\033[0m")
630
+ x = blk(x, residual=residual)
631
+ # return intermediate features
632
+ if idx in self.return_index:
633
+ if isinstance(x, tuple) and len(x) == 2:
634
+ tmp_x, tmp_residual = x
635
+ if residual is not None:
636
+ x_clip.append(tmp_x + tmp_residual)
637
+ else:
638
+ x_clip.append(x)
639
+ if idx == (self.depth + x_vis_return_idx):
640
+ # print(f'idx = {idx} len(self.blocks)={len(self.blocks)}')
641
+ break
642
+
643
+ if isinstance(x, tuple) and len(x) == 2:
644
+ x, residual = x
645
+ if residual is not None:
646
+ x = x + residual
647
+
648
+ x_vis = x
649
+ # print(f'x_vis.shape:{x_vis.shape}')
650
+ if x_vis_only:
651
+ return x_vis
652
+
653
+ x_pool_vis = self.clip_projector(x_vis)
654
+ x_align = self.final_clip_decoder(x_pool_vis)
655
+ # print(3, x_pool_vis.shape)
656
+ # print(4, x_align.shape)
657
+
658
+ # align CLIP
659
+ x_clip = torch.stack(x_clip)
660
+ K, B, _, C_CLIP = x_clip.shape
661
+ # print(5, x_clip.shape)
662
+ # add pos_embed
663
+ if self.sep_pos_embed:
664
+ raise NotImplementedError
665
+ else:
666
+ if use_image:
667
+ if self.sep_image_video_pos_embed:
668
+ clip_pos_embed = self.clip_img_pos_embed
669
+ else:
670
+ # (1, num_img_patches + 1, embed_dim)
671
+ # print('origin pos_embed.shape:', self.pos_embed.shape)
672
+ clip_cls_pos_embed = self.clip_pos_embed[:, 0:1, :]
673
+ # print('cls_pos_embed.shape:', cls_pos_embed.shape)
674
+
675
+ clip_img_pos_embed = self.clip_pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
676
+ # print('img_pos_embed.shape:', img_pos_embed.shape)
677
+
678
+ clip_pos_embed = torch.cat([clip_cls_pos_embed, clip_img_pos_embed], dim=1)
679
+ # print('final img_pos_embed.shape:', pos_embed.shape)
680
+
681
+ else:
682
+ clip_pos_embed = self.clip_pos_embed
683
+
684
+ clip_pos_embed = clip_pos_embed.repeat(B, 1, 1)
685
+ if mask is not None:
686
+ x_clip = x_clip + clip_pos_embed[~mask].view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1)
687
+ else:
688
+ clip_pos_embed = clip_pos_embed.unsqueeze(0).repeat(K, 1, 1, 1)
689
+ clip_pos_embed = clip_pos_embed[:, :, :x_clip.shape[2], :]
690
+ x_clip = x_clip + clip_pos_embed
691
+
692
+ # CLIP decoder
693
+ x_clip_align = []
694
+ for idx, clip_decoder in enumerate(self.clip_decoder):
695
+ x_clip_align.append(clip_decoder(x_clip[idx]))
696
+ x_clip_align = torch.stack(x_clip_align)
697
+
698
+ # print(f'x_vis.shape:{x_vis.shape}, x_pool_vis.shape:{x_pool_vis.shape}')
699
+ return x_vis, x_pool_vis, x_clip_align, x_align
700
+
701
+
702
+ def pretrain_internvideo2_1b_patch14_224(config):
703
+ # print(config.vision_encoder.num_frames)
704
+ model = PretrainInternVideo2(
705
+ in_chans=3, img_size=224, patch_size=14,
706
+ embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
707
+ clip_embed_dim=config.vision_encoder.clip_embed_dim,
708
+ attn_pool_num_heads=16, qkv_bias=False,
709
+ drop_path_rate=0.25,
710
+ init_values=0.00001,
711
+ qk_normalization=True,
712
+ use_flash_attn=config.vision_encoder.get('use_flash_attn', True),
713
+ use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', True),
714
+ use_fused_mlp=config.vision_encoder.get('use_fused_mlp', True),
715
+ fused_mlp_heuristic=1,
716
+ layerscale_no_force_fp32=False,
717
+ num_frames=config.vision_encoder.num_frames,
718
+ tubelet_size=config.vision_encoder.tubelet_size,
719
+ sep_pos_embed=False,
720
+ sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
721
+ use_checkpoint=config.vision_encoder.use_checkpoint,
722
+ checkpoint_num=config.vision_encoder.checkpoint_num,
723
+ clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim,
724
+ clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim,
725
+ clip_norm_type=config.vision_encoder.clip_norm_type,
726
+ clip_return_layer=config.vision_encoder.clip_return_layer,
727
+ clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
728
+ )
729
+
730
+ if config.vision_encoder.pretrained is not None:
731
+ # logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
732
+ state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
733
+ interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
734
+ message = model.load_state_dict(state_dict, strict=False)
735
+ # logger.info(message)
736
+ else:
737
+ pass
738
+ # logger.info("No pretrained weights!!!")
739
+ return model
740
+
741
+
742
+
743
+ def pretrain_internvideo2_6b_patch14_224(config):
744
+ model = PretrainInternVideo2(
745
+ in_chans=3, img_size=224, patch_size=14,
746
+ embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4,
747
+ clip_embed_dim=config.vision_encoder.clip_embed_dim,
748
+ attn_pool_num_heads=16, qkv_bias=False,
749
+ drop_path_rate=0.3,
750
+ init_values=0.00001,
751
+ qk_normalization=True,
752
+ use_flash_attn=config.vision_encoder.get('use_flash_attn', True),
753
+ use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', True),
754
+ use_fused_mlp=config.vision_encoder.get('use_fused_mlp', True),
755
+ fused_mlp_heuristic=1,
756
+ layerscale_no_force_fp32=False,
757
+ num_frames=config.vision_encoder.num_frames,
758
+ tubelet_size=config.vision_encoder.tubelet_size,
759
+ sep_pos_embed=False,
760
+ sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
761
+ use_checkpoint=config.vision_encoder.use_checkpoint,
762
+ checkpoint_num=config.vision_encoder.checkpoint_num,
763
+ clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim,
764
+ clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim,
765
+ clip_norm_type=config.vision_encoder.clip_norm_type,
766
+ clip_return_layer=config.vision_encoder.clip_return_layer,
767
+ clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
768
+ )
769
+
770
+ if config.vision_encoder.pretrained is not None:
771
+ # logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
772
+ state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
773
+ interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
774
+ msg = model.load_state_dict(state_dict, strict=False)
775
+ # logger.info(msg)
776
+ else:
777
+ pass
778
+ # logger.info("No pretrained weights!!!")
779
+ return model
780
+
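A quick note on the sequence length the 1B encoder above produces: with the defaults of `pretrain_internvideo2_1b_patch14_224` (224×224 input, patch size 14, 8 frames, tubelet size 1), `PatchEmbed` yields an 8×16×16 token grid. A sketch of the arithmetic (not part of the commit):

```python
img_size, patch_size, num_frames, tubelet_size = 224, 14, 8, 1
grid = (num_frames // tubelet_size, img_size // patch_size, img_size // patch_size)  # (T, H, W)
num_patches = grid[0] * grid[1] * grid[2]
print(grid, num_patches, num_patches + 1)  # (8, 16, 16) 2048 2049  (+1 for the cls token)
```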
internvideo2_stage2.py ADDED
@@ -0,0 +1,101 @@
1
+ import logging
2
+ import json
3
+ import torch
4
+ from torch import nn
5
+ from .config import InternVideo2Config, EasyDict
6
+ from .internvideo2 import pretrain_internvideo2_1b_patch14_224, pretrain_internvideo2_6b_patch14_224
7
+ from transformers.utils import logging
8
+ import warnings
9
+
10
+ warnings.filterwarnings("ignore")
11
+
12
+ class InternVideo2_Stage2(nn.Module):
13
+ """docstring for InternVideo2_Stage2"""
14
+
15
+ def __init__(self, config, is_pretrain=True):
16
+ super(InternVideo2_Stage2, self).__init__()
17
+
18
+ # if isinstance(config, InternVideo2Config):
19
+ # config_str = str(config)
20
+ # config_str = config_str.replace('InternVideo2Config ', '')
21
+ # config_json = json.loads(config_str)
22
+ # config = EasyDict(config_json)
23
+ # self.config = config
24
+
25
+ self.config = config
26
+
27
+ self.is_pretrain = is_pretrain
28
+ self.vision_width = config.model.vision_encoder.clip_embed_dim
29
+ # self.text_width = config.model.text_encoder.d_model
30
+ self.embed_dim = config.model.embed_dim
31
+
32
+ # create modules.
33
+ self.vision_encoder = self.build_vision_encoder()
34
+ if config.model.get("freeze_vision", False):
35
+ self.freeze_vision()
36
+
37
+ self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
38
+
39
+ self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
40
+ self.uta_image_only = config.criterion.get('uta_image_only', False)
41
+
42
+ # logger.info(f"uta_image_only={self.uta_image_only}")
43
+
44
+ def freeze_vision(self):
45
+ """freeze vision encoder"""
46
+ for p in self.vision_encoder.parameters():
47
+ p.requires_grad = False
48
+
49
+ def no_weight_decay(self):
50
+ ret = {"temp"}
51
+ ret.update(
52
+ {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
53
+ )
54
+ # ret.update(
55
+ # {"text_encoder." + k for k in self.text_encoder.no_weight_decay()}
56
+ # )
57
+
58
+ return ret
59
+
60
+ @property
61
+ def dtype(self):
62
+ return self.vision_encoder.patch_embed.proj.weight.dtype
63
+
64
+ def encode_vision(self, image):
65
+ """encode image / videos as features.
66
+
67
+ Args:
68
+ image (torch.Tensor): The input images. Shape(B, N, C, H, W)
69
+ test (bool): Whether testing.
70
+
71
+ Returns: tuple.
72
+ - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C].
73
+ - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C].
74
+ - student_output (torch.Tensor): The features of alignment. Shape: [K,B,N,C].
75
+ - clip_output (torch.Tensor): The features of clip. Shape: [K,B,N,C].
76
+
77
+ """
78
+ T = image.shape[1]
79
+ use_image = True if T == 1 else False
80
+ image = image.permute(0, 2, 1, 3, 4) # [B,N,C,H,W] -> [B,C,N,H,W]
81
+ # whether save temporal dimension
82
+ # keep_temporal=self.config.model.vision_encoder.keep_temporal
83
+ vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
84
+ image, None, use_image)
85
+ return vision_embeds, pooled_vision_embeds
86
+
87
+ def build_vision_encoder(self):
88
+ """build vision encoder
89
+ Returns: (vision_encoder, clip_teacher). Each is a `nn.Module`.
90
+
91
+ """
92
+ encoder_name = self.config.model.vision_encoder.name
93
+ # logger.info(f"Build vision_encoder: {encoder_name}")
94
+ if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
95
+ vision_encoder = pretrain_internvideo2_1b_patch14_224(self.config.model)
96
+ elif encoder_name == 'pretrain_internvideo2_6b_patch14_224':
97
+ vision_encoder = pretrain_internvideo2_6b_patch14_224(self.config.model)
98
+ else:
99
+ raise ValueError(f"Not implemented: {encoder_name}")
100
+ return vision_encoder
101
+
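Expected call pattern for `encode_vision` above, as a shape-only sketch. It assumes `model` is an `InternVideo2_Stage2` built from this repo's config; the actual call is left commented out because it needs the checkpoint and, with `use_flash_attn`, a CUDA device:

```python
import torch

video = torch.randn(1, 8, 3, 224, 224, dtype=torch.float16)  # [B, N, C, H, W], N = num_frames
# Inside encode_vision: use_image = (N == 1), and the tensor is permuted to
# [B, C, N, H, W] before being passed to the vision encoder.
# vision_embeds, pooled_vision_embeds = model.encode_vision(video)
# vision_embeds: [B, 1 + N*16*16, 1408] for the 1B encoder (cls token + patch tokens)
```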
model.py CHANGED
@@ -1,6 +1,6 @@
- from internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
+ from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
  from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
- from config import InternVideo2Config as config
+ from .config import InternVideo2Config as config
  import warnings
  import torch
  # from transformers.utils import logging
@@ -19,7 +19,7 @@ class InternVideo2Stage2VideoEncoder(PreTrainedModel):
  super().__init__(config)
  self.config = config
  # print(self.config.model.vision_encoder.num_frames)
- self.model = IV2S2(self.config).to(config.device).to(torch.float16)
+ self.model = IV2S2(self.config).to('cpu').to(torch.float16)

  def forward(self, x: torch.tensor):
  """forward pass
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:d188c47f1651b851d63195fc8439ba8e816eca9fe068e7685140f41b04638b3f
+ oid sha256:b846fa2b0540df04a40b8e54568a667de8b03c2d2d8c0062aaa4b606a23fc174
  size 2104856154
pos_embed.py ADDED
@@ -0,0 +1,299 @@
1
+ import numpy as np
2
+ import torch
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ # --------------------------------------------------------
8
+ # 3D sine-cosine position embedding
9
+ # References:
10
+ # MVD: https://github.com/ruiwang2021/mvd/blob/main/modeling_finetune.py
11
+ # --------------------------------------------------------
12
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
13
+ """
14
+ grid_size: int of the grid height and width
15
+ t_size: int of the temporal size
16
+ return:
17
+ pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
18
+ """
19
+ assert embed_dim % 4 == 0
20
+ embed_dim_spatial = embed_dim // 4 * 3
21
+ embed_dim_temporal = embed_dim // 4
22
+
23
+ # spatial
24
+ grid_h = np.arange(grid_size, dtype=np.float32)
25
+ grid_w = np.arange(grid_size, dtype=np.float32)
26
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
27
+ grid = np.stack(grid, axis=0)
28
+
29
+ grid = grid.reshape([2, 1, grid_size, grid_size])
30
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
31
+ embed_dim_spatial, grid
32
+ )
33
+
34
+ # temporal
35
+ grid_t = np.arange(t_size, dtype=np.float32)
36
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
37
+ embed_dim_temporal, grid_t
38
+ )
39
+
40
+ # concate: [T, H, W] order
41
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
42
+ pos_embed_temporal = np.repeat(
43
+ pos_embed_temporal, grid_size**2, axis=1
44
+ ) # [T, H*W, D // 4]
45
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
46
+ pos_embed_spatial = np.repeat(
47
+ pos_embed_spatial, t_size, axis=0
48
+ ) # [T, H*W, D // 4 * 3]
49
+
50
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
51
+ pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
52
+
53
+ if cls_token:
54
+ pos_embed = np.concatenate(
55
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
56
+ )
57
+ return pos_embed
58
+
59
+
60
+ # --------------------------------------------------------
61
+ # 2D sine-cosine position embedding
62
+ # References:
63
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
64
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
65
+ # --------------------------------------------------------
66
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
67
+ """
68
+ grid_size: int of the grid height and width
69
+ return:
70
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
71
+ """
72
+ grid_h = np.arange(grid_size, dtype=np.float32)
73
+ grid_w = np.arange(grid_size, dtype=np.float32)
74
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
75
+ grid = np.stack(grid, axis=0)
76
+
77
+ grid = grid.reshape([2, 1, grid_size, grid_size])
78
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
79
+ if cls_token:
80
+ pos_embed = np.concatenate(
81
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
82
+ )
83
+ return pos_embed
84
+
85
+
86
+ def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
87
+ """
88
+ t_size: int of the temporal size
89
+ return:
90
+ pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token)
91
+ """
92
+ grid_t = np.arange(t_size, dtype=np.float32)
93
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
94
+ if cls_token:
95
+ pos_embed = np.concatenate(
96
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
97
+ )
98
+ return pos_embed
99
+
100
+
101
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
102
+ assert embed_dim % 2 == 0
103
+
104
+ # use half of dimensions to encode grid_h
105
+ emb_h = get_1d_sincos_pos_embed_from_grid(
106
+ embed_dim // 2, grid[0]
107
+ ) # (H*W, D/2)
108
+ emb_w = get_1d_sincos_pos_embed_from_grid(
109
+ embed_dim // 2, grid[1]
110
+ ) # (H*W, D/2)
111
+
112
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
113
+ return emb
114
+
115
+
116
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
117
+ """
118
+ embed_dim: output dimension for each position
119
+ pos: a list of positions to be encoded: size (M,)
120
+ out: (M, D)
121
+ """
122
+ assert embed_dim % 2 == 0
123
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
124
+ omega /= embed_dim / 2.0
125
+ omega = 1.0 / 10000**omega # (D/2,)
126
+
127
+ pos = pos.reshape(-1) # (M,)
128
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
129
+
130
+ emb_sin = np.sin(out) # (M, D/2)
131
+ emb_cos = np.cos(out) # (M, D/2)
132
+
133
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
134
+ return emb
135
+
136
+
137
+ def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'):
138
+ if pos_name in checkpoint_model:
139
+ pos_embed_checkpoint = checkpoint_model[pos_name]
140
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
141
+ num_patches = model.patch_embed.num_patches #
142
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
143
+
144
+ # we use 4 frames for pretraining
145
+ new_t_size = model.T
146
+ # height (== width) for the checkpoint position embedding
147
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
148
+ # height (== width) for the new position embedding
149
+ new_size = int((num_patches // (new_t_size))** 0.5)
150
+
151
+ # class_token and dist_token are kept unchanged
152
+ if orig_t_size != new_t_size:
153
+ logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
154
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
155
+ # only the position tokens are interpolated
156
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
157
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
158
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
159
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
160
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
161
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
162
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
163
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
164
+ checkpoint_model[pos_name] = new_pos_embed
165
+ pos_embed_checkpoint = new_pos_embed
166
+
167
+ # class_token and dist_token are kept unchanged
168
+ if orig_size != new_size:
169
+ logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
170
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
171
+ # only the position tokens are interpolated
172
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
173
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
174
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
175
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
176
+ pos_tokens = torch.nn.functional.interpolate(
177
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
178
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
179
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
180
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
181
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
182
+ checkpoint_model[pos_name] = new_pos_embed
183
+
184
+
185
+ def interpolate_pos_embed_internvideo2(checkpoint_model, model, orig_t_size = 8):
186
+ # interpolate position embedding
187
+ for pos_name in ['pos_embed', 'clip_pos_embed']:
188
+ if pos_name in checkpoint_model:
189
+ pos_embed_checkpoint = checkpoint_model[pos_name]
190
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
191
+ num_patches = model.patch_embed.num_patches #
192
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
193
+
194
+ # we use 8 frames for pretraining
195
+ # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
196
+ new_t_size = model.num_frames // model.tubelet_size
197
+ # height (== width) for the checkpoint position embedding
198
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
199
+ # height (== width) for the new position embedding
200
+ new_size = int((num_patches // (new_t_size))** 0.5)
201
+
202
+ # class_token and dist_token are kept unchanged
203
+ if orig_t_size != new_t_size:
204
+ logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
205
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
206
+ # only the position tokens are interpolated
207
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
208
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
209
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
210
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
211
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
212
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
213
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
214
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
215
+ checkpoint_model[pos_name] = new_pos_embed
216
+ pos_embed_checkpoint = new_pos_embed
217
+
218
+ # class_token and dist_token are kept unchanged
219
+ if orig_size != new_size:
220
+ logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
221
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
222
+ # only the position tokens are interpolated
223
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
224
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
225
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
226
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
227
+ pos_tokens = torch.nn.functional.interpolate(
228
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
229
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
230
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
231
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
232
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
233
+ checkpoint_model[pos_name] = new_pos_embed
234
+
235
+ if 'pos_embed_spatial' in checkpoint_model or 'pos_embed_temporal' in checkpoint_model:
236
+ raise NotImplementedError
237
+
238
+
239
+ def interpolate_pos_embed_internvideo2_new(checkpoint_model, model, orig_t_size = 8):
240
+ pos_names = []
241
+ for k in checkpoint_model.keys():
242
+ if ('pos_embed' in k or 'clip_pos_embed' in k) and 'img_pos_embed' not in k:
243
+ pos_names.append(k)
244
+
245
+ logger.info(f"pos names list for interpolating: {pos_names}")
246
+
247
+ assert len(pos_names) > 0, checkpoint_model.keys()
248
+
249
+ if 'pos_embed_spatial' in checkpoint_model.keys() or 'pos_embed_temporal' in checkpoint_model.keys():
250
+ raise NotImplementedError
251
+
252
+ # interpolate position embedding
253
+ for pos_name in pos_names:
254
+
255
+ pos_embed_checkpoint = checkpoint_model[pos_name]
256
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
257
+ num_patches = model.patch_embed.num_patches #
258
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
259
+
260
+ # we use 8 frames for pretraining
261
+ # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
262
+ new_t_size = model.num_frames // model.tubelet_size
263
+ # height (== width) for the checkpoint position embedding
264
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
265
+ # height (== width) for the new position embedding
266
+ new_size = int((num_patches // (new_t_size))** 0.5)
267
+
268
+ # class_token and dist_token are kept unchanged
269
+ if orig_t_size != new_t_size:
270
+ logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
271
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
272
+ # only the position tokens are interpolated
273
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
274
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
275
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
276
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
277
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
278
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
279
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
280
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
281
+ checkpoint_model[pos_name] = new_pos_embed
282
+ pos_embed_checkpoint = new_pos_embed
283
+
284
+ # class_token and dist_token are kept unchanged
285
+ if orig_size != new_size:
286
+ logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
287
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
288
+ # only the position tokens are interpolated
289
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
290
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
291
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
292
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
293
+ pos_tokens = torch.nn.functional.interpolate(
294
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
295
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
296
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
297
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
298
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
299
+ checkpoint_model[pos_name] = new_pos_embed
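A small shape check for `get_3d_sincos_pos_embed` as it is used by `init_pos_embed()` in internvideo2.py (embed_dim 1408, 16×16 spatial grid, 8 frames). The import assumes this file is importable as a plain module; in the repo it lives in a package and is imported relatively:

```python
from pos_embed import get_3d_sincos_pos_embed

pe = get_3d_sincos_pos_embed(1408, grid_size=16, t_size=8, cls_token=True)
print(pe.shape)  # (2049, 1408): one all-zero cls slot + 8*16*16 patch positions
```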