BAAI /

3v324v23 committed on
Commit a5ce455 · 1 Parent(s): 8e7f18b
Files changed (12)
  1. attention_temporal_videoae.py +1314 -0
  2. base_encoder.py +68 -0
  3. builder.py +17 -0
  4. llava_arch.py +76 -52
  5. llava_qwen.py +44 -24
  6. mm_utils.py +18 -14
  7. modeling_qwen2.py +4 -1
  8. sae.py +45 -0
  9. sae_utils.py +302 -0
  10. siglip_encoder.py +154 -0
  11. utils.py +166 -0
  12. utils_encoder.py +296 -0
attention_temporal_videoae.py ADDED
@@ -0,0 +1,1314 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch as th
5
+ import torch.nn.functional as F
6
+ from torch import nn, einsum
7
+ from einops import rearrange, repeat
8
+ from typing import Optional, Any
9
+
10
+ try:
11
+ import xformers
12
+ import xformers.ops
13
+
14
+ XFORMERS_IS_AVAILBLE = True
15
+ except:
16
+ XFORMERS_IS_AVAILBLE = False
17
+
18
+ from .utils_encoder import (
19
+ conv_nd,
20
+ zero_module,
21
+ normalization,
22
+ )
23
+
24
+
25
+ def exists(val):
26
+ return val is not None
27
+
28
+
29
+ def uniq(arr):
30
+ return {el: True for el in arr}.keys()
31
+
32
+
33
+ def default(val, d):
34
+ if exists(val):
35
+ return val
36
+ return d() if isfunction(d) else d
37
+
38
+
39
+ def max_neg_value(t):
40
+ return -torch.finfo(t.dtype).max
41
+
42
+
43
+ def init_(tensor):
44
+ dim = tensor.shape[-1]
45
+ std = 1 / math.sqrt(dim)
46
+ tensor.uniform_(-std, std)
47
+ return tensor
48
+
49
+
50
+ # feedforward
51
+ class GEGLU(nn.Module):
52
+ def __init__(self, dim_in, dim_out):
53
+ super().__init__()
54
+ self.proj = nn.Linear(dim_in, dim_out * 2)
55
+
56
+ def forward(self, x):
57
+ x, gate = self.proj(x).chunk(2, dim=-1)
58
+ return x * F.gelu(gate)
59
+
60
+
61
+ class FeedForward(nn.Module):
62
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
63
+ super().__init__()
64
+ inner_dim = int(dim * mult)
65
+ dim_out = default(dim_out, dim)
66
+ project_in = (
67
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
68
+ if not glu
69
+ else GEGLU(dim, inner_dim)
70
+ )
71
+
72
+ self.net = nn.Sequential(
73
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
74
+ )
75
+
76
+ def forward(self, x):
77
+ return self.net(x)
78
+
79
+
80
+ def zero_module(module):
81
+ """
82
+ Zero out the parameters of a module and return it.
83
+ """
84
+ for p in module.parameters():
85
+ p.detach().zero_()
86
+ return module
87
+
88
+
89
+ def Normalize(in_channels, num_groups=32):
90
+ return torch.nn.GroupNorm(
91
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
92
+ )
93
+
94
+
95
+ # ---------------------------------------------------------------------------------------------------
96
+ class RelativePosition(nn.Module):
97
+ """https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py"""
98
+
99
+ def __init__(self, num_units, max_relative_position):
100
+ super().__init__()
101
+ self.num_units = num_units
102
+ self.max_relative_position = max_relative_position
103
+ self.embeddings_table = nn.Parameter(
104
+ th.Tensor(max_relative_position * 2 + 1, num_units)
105
+ )
106
+ nn.init.xavier_uniform_(self.embeddings_table)
107
+
108
+ def forward(self, length_q, length_k):
109
+ device = self.embeddings_table.device
110
+ range_vec_q = th.arange(length_q, device=device)
111
+ range_vec_k = th.arange(length_k, device=device)
112
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
113
+ distance_mat_clipped = th.clamp(
114
+ distance_mat, -self.max_relative_position, self.max_relative_position
115
+ )
116
+ final_mat = distance_mat_clipped + self.max_relative_position
117
+ # final_mat = th.LongTensor(final_mat).to(self.embeddings_table.device)
118
+ # final_mat = th.tensor(final_mat, device=self.embeddings_table.device, dtype=torch.long)
119
+ final_mat = final_mat.long()
120
+ embeddings = self.embeddings_table[final_mat]
121
+ return embeddings
122
+
123
+
124
+ class TemporalCrossAttention(nn.Module):
125
+ def __init__(
126
+ self,
127
+ query_dim,
128
+ context_dim=None,
129
+ heads=8,
130
+ dim_head=64,
131
+ dropout=0.0,
132
+ temporal_length=None, # For relative positional representation and image-video joint training.
133
+ image_length=None, # For image-video joint training.
134
+ use_relative_position=False, # whether to use relative positional representation in temporal attention.
135
+ img_video_joint_train=False, # For image-video joint training.
136
+ use_tempoal_causal_attn=False,
137
+ bidirectional_causal_attn=False,
138
+ tempoal_attn_type=None,
139
+ joint_train_mode="same_batch",
140
+ **kwargs,
141
+ ):
142
+ super().__init__()
143
+ inner_dim = dim_head * heads
144
+ context_dim = default(context_dim, query_dim)
145
+ self.context_dim = context_dim
146
+
147
+ self.scale = dim_head**-0.5
148
+ self.heads = heads
149
+ self.temporal_length = temporal_length
150
+ self.use_relative_position = use_relative_position
151
+ self.img_video_joint_train = img_video_joint_train
152
+ self.bidirectional_causal_attn = bidirectional_causal_attn
153
+ self.joint_train_mode = joint_train_mode
154
+ assert joint_train_mode in ["same_batch", "diff_batch"]
155
+ self.tempoal_attn_type = tempoal_attn_type
156
+
157
+ if bidirectional_causal_attn:
158
+ assert use_tempoal_causal_attn
159
+ if tempoal_attn_type:
160
+ assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"]
161
+ assert not use_tempoal_causal_attn
162
+ assert not (
163
+ img_video_joint_train and (self.joint_train_mode == "same_batch")
164
+ )
165
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
166
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
167
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
168
+
169
+ assert not (
170
+ img_video_joint_train
171
+ and (self.joint_train_mode == "same_batch")
172
+ and use_tempoal_causal_attn
173
+ )
174
+ if img_video_joint_train:
175
+ if self.joint_train_mode == "same_batch":
176
+ mask = torch.ones(
177
+ [1, temporal_length + image_length, temporal_length + image_length]
178
+ )
179
+ # mask[:, image_length:, :] = 0
180
+ # mask[:, :, image_length:] = 0
181
+ mask[:, temporal_length:, :] = 0
182
+ mask[:, :, temporal_length:] = 0
183
+ self.mask = mask
184
+ else:
185
+ self.mask = None
186
+ elif use_tempoal_causal_attn:
187
+ # normal causal attn
188
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
189
+ elif tempoal_attn_type == "sparse_causal":
190
+ # all frames interact with only the `prev` & self frame
191
+ mask1 = torch.tril(
192
+ torch.ones([1, temporal_length, temporal_length])
193
+ ).bool() # true indicates keeping
194
+ mask2 = torch.zeros(
195
+ [1, temporal_length, temporal_length]
196
+ ) # initialize to same shape with mask1
197
+ mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
198
+ torch.ones([1, temporal_length - 2, temporal_length - 2])
199
+ )
200
+ mask2 = (1 - mask2).bool() # false indicates masking
201
+ self.mask = mask1 & mask2
202
+ elif tempoal_attn_type == "sparse_causal_first":
203
+ # all frames interact with only the `first` & self frame
204
+ mask1 = torch.tril(
205
+ torch.ones([1, temporal_length, temporal_length])
206
+ ).bool() # true indicates keeping
207
+ mask2 = torch.zeros([1, temporal_length, temporal_length])
208
+ mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril(
209
+ torch.ones([1, temporal_length - 2, temporal_length - 2])
210
+ )
211
+ mask2 = (1 - mask2).bool() # false indicates masking
212
+ self.mask = mask1 & mask2
213
+ else:
214
+ self.mask = None
215
+
216
+ if use_relative_position:
217
+ assert temporal_length is not None
218
+ self.relative_position_k = RelativePosition(
219
+ num_units=dim_head, max_relative_position=temporal_length
220
+ )
221
+ self.relative_position_v = RelativePosition(
222
+ num_units=dim_head, max_relative_position=temporal_length
223
+ )
224
+
225
+ self.to_out = nn.Sequential(
226
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
227
+ )
228
+
229
+ nn.init.constant_(self.to_q.weight, 0)
230
+ nn.init.constant_(self.to_k.weight, 0)
231
+ nn.init.constant_(self.to_v.weight, 0)
232
+ nn.init.constant_(self.to_out[0].weight, 0)
233
+ nn.init.constant_(self.to_out[0].bias, 0)
234
+
235
+ def forward(self, x, context=None, mask=None):
236
+ # if context is None:
237
+ # print(f'[Temp Attn] x={x.shape},context=None')
238
+ # else:
239
+ # print(f'[Temp Attn] x={x.shape},context={context.shape}')
240
+
241
+ nh = self.heads
242
+ out = x
243
+ q = self.to_q(out)
244
+ # if context is not None:
245
+ # print(f'temporal context 1 ={context.shape}')
246
+ # print(f'x={x.shape}')
247
+ context = default(context, x)
248
+ # print(f'temporal context 2 ={context.shape}')
249
+ k = self.to_k(context)
250
+ v = self.to_v(context)
251
+ # print(f'q ={q.shape},k={k.shape}')
252
+
253
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v))
254
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
255
+
256
+ if self.use_relative_position:
257
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
258
+ k2 = self.relative_position_k(len_q, len_k)
259
+ sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale # TODO check
260
+ sim += sim2
261
+ # print('mask',mask)
262
+ if exists(self.mask):
263
+ if mask is None:
264
+ mask = self.mask.to(sim.device)
265
+ else:
266
+ mask = self.mask.to(sim.device).bool() & mask # .to(sim.device)
267
+ else:
268
+ mask = mask
269
+ # if self.img_video_joint_train:
270
+ # # process mask (make mask same shape with sim)
271
+ # c, h, w = mask.shape
272
+ # c, t, s = sim.shape
273
+ # # assert(h == w and t == s),f"mask={mask.shape}, sim={sim.shape}, h={h}, w={w}, t={t}, s={s}"
274
+
275
+ # if h > t:
276
+ # mask = mask[:, :t, :]
277
+ # elif h < t: # pad zeros to mask (no attention) only initial mask =1 area compute weights
278
+ # mask_ = torch.zeros([c,t,w]).to(mask.device)
279
+ # mask_[:, :h, :] = mask
280
+ # mask = mask_
281
+ # c, h, w = mask.shape
282
+ # if w > s:
283
+ # mask = mask[:, :, :s]
284
+ # elif w < s: # pad zeros to mask
285
+ # mask_ = torch.zeros([c,h,s]).to(mask.device)
286
+ # mask_[:, :, :w] = mask
287
+ # mask = mask_
288
+
289
+ # max_neg_value = -torch.finfo(sim.dtype).max
290
+ # sim = sim.float().masked_fill(mask == 0, max_neg_value)
291
+ if mask is not None:
292
+ max_neg_value = -1e9
293
+ sim = sim + (1 - mask.float()) * max_neg_value # in mask: 1 = keep, 0 = masked out
294
+ # print('sim after masking: ', sim)
295
+
296
+ # if torch.isnan(sim).any() or torch.isinf(sim).any() or (not sim.any()):
297
+ # print(f'sim [after masking], isnan={torch.isnan(sim).any()}, isinf={torch.isinf(sim).any()}, allzero={not sim.any()}')
298
+
299
+ attn = sim.softmax(dim=-1)
300
+ # print('attn after softmax: ', attn)
301
+ # if torch.isnan(attn).any() or torch.isinf(attn).any() or (not attn.any()):
302
+ # print(f'attn [after softmax], isnan={torch.isnan(attn).any()}, isinf={torch.isinf(attn).any()}, allzero={not attn.any()}')
303
+
304
+ # attn = torch.where(torch.isnan(attn), torch.full_like(attn,0), attn)
305
+ # if torch.isinf(attn.detach()).any():
306
+ # import pdb;pdb.set_trace()
307
+ # if torch.isnan(attn.detach()).any():
308
+ # import pdb;pdb.set_trace()
309
+ out = einsum("b i j, b j d -> b i d", attn, v)
310
+
311
+ if self.bidirectional_causal_attn:
312
+ mask_reverse = torch.triu(
313
+ torch.ones(
314
+ [1, self.temporal_length, self.temporal_length], device=sim.device
315
+ )
316
+ )
317
+ sim_reverse = sim.float().masked_fill(mask_reverse == 0, max_neg_value)
318
+ attn_reverse = sim_reverse.softmax(dim=-1)
319
+ out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
320
+ out += out_reverse
321
+
322
+ if self.use_relative_position:
323
+ v2 = self.relative_position_v(len_q, len_v)
324
+ out2 = einsum("b t s, t s d -> b t d", attn, v2) # TODO check
325
+ out += out2 # TODO check: add first or merge heads first? Here the relative-position term is computed on the split-head tensors, then the heads are merged.
326
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=nh) # merge head
327
+ return self.to_out(out)
328
+
329
+
330
+ # ---------------------------------------------------------------------------------------------------
331
+
332
+
333
+ class SpatialSelfAttention(nn.Module):
334
+ def __init__(self, in_channels):
335
+ super().__init__()
336
+ self.in_channels = in_channels
337
+
338
+ self.norm = Normalize(in_channels)
339
+ self.q = torch.nn.Conv2d(
340
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
341
+ )
342
+ self.k = torch.nn.Conv2d(
343
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
344
+ )
345
+ self.v = torch.nn.Conv2d(
346
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
347
+ )
348
+ self.proj_out = torch.nn.Conv2d(
349
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
350
+ )
351
+
352
+ def forward(self, x):
353
+ h_ = x
354
+ h_ = self.norm(h_)
355
+ q = self.q(h_)
356
+ k = self.k(h_)
357
+ v = self.v(h_)
358
+
359
+ # compute attention
360
+ b, c, h, w = q.shape
361
+ q = rearrange(q, "b c h w -> b (h w) c")
362
+ k = rearrange(k, "b c h w -> b c (h w)")
363
+ w_ = torch.einsum("bij,bjk->bik", q, k)
364
+
365
+ w_ = w_ * (int(c) ** (-0.5))
366
+ w_ = torch.nn.functional.softmax(w_, dim=2)
367
+
368
+ # attend to values
369
+ v = rearrange(v, "b c h w -> b c (h w)")
370
+ w_ = rearrange(w_, "b i j -> b j i")
371
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
372
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
373
+ h_ = self.proj_out(h_)
374
+
375
+ return x + h_
376
+
377
+
378
+ class CrossAttention(nn.Module):
379
+ def __init__(
380
+ self,
381
+ query_dim,
382
+ context_dim=None,
383
+ heads=8,
384
+ dim_head=64,
385
+ dropout=0.0,
386
+ sa_shared_kv=False,
387
+ shared_type="only_first",
388
+ **kwargs,
389
+ ):
390
+ super().__init__()
391
+ inner_dim = dim_head * heads
392
+ context_dim = default(context_dim, query_dim)
393
+ self.sa_shared_kv = sa_shared_kv
394
+ assert shared_type in [
395
+ "only_first",
396
+ "all_frames",
397
+ "first_and_prev",
398
+ "only_prev",
399
+ "full",
400
+ "causal",
401
+ "full_qkv",
402
+ ]
403
+ self.shared_type = shared_type
404
+
405
+ self.scale = dim_head**-0.5
406
+ self.heads = heads
407
+ self.dim_head = dim_head
408
+
409
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
410
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
411
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
412
+
413
+ self.to_out = nn.Sequential(
414
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
415
+ )
416
+ self.attention_op: Optional[Any] = None
417
+
418
+ def forward(self, x, context=None, mask=None):
419
+ h = self.heads
420
+ b = x.shape[0]
421
+
422
+ q = self.to_q(x)
423
+ context = default(context, x)
424
+ k = self.to_k(context)
425
+ v = self.to_v(context)
426
+ if self.sa_shared_kv:
427
+ if self.shared_type == "only_first":
428
+ k, v = map(
429
+ lambda xx: rearrange(xx[0].unsqueeze(0), "b n c -> (b n) c")
430
+ .unsqueeze(0)
431
+ .repeat(b, 1, 1),
432
+ (k, v),
433
+ )
434
+ else:
435
+ raise NotImplementedError
436
+
437
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
438
+
439
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
440
+
441
+ if exists(mask):
442
+ mask = rearrange(mask, "b ... -> b (...)")
443
+ max_neg_value = -torch.finfo(sim.dtype).max
444
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
445
+ sim.masked_fill_(~mask, max_neg_value)
446
+
447
+ # attention, what we cannot get enough of
448
+ attn = sim.softmax(dim=-1)
449
+
450
+ out = einsum("b i j, b j d -> b i d", attn, v)
451
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
452
+ return self.to_out(out)
453
+
454
+ def efficient_forward(self, x, context=None, mask=None):
455
+ q = self.to_q(x)
456
+ context = default(context, x)
457
+ k = self.to_k(context)
458
+ v = self.to_v(context)
459
+
460
+ b, _, _ = q.shape
461
+ q, k, v = map(
462
+ lambda t: t.unsqueeze(3)
463
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
464
+ .permute(0, 2, 1, 3)
465
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
466
+ .contiguous(),
467
+ (q, k, v),
468
+ )
469
+ # actually compute the attention, what we cannot get enough of
470
+ out = xformers.ops.memory_efficient_attention(
471
+ q, k, v, attn_bias=None, op=self.attention_op
472
+ )
473
+
474
+ if exists(mask):
475
+ raise NotImplementedError
476
+ out = (
477
+ out.unsqueeze(0)
478
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
479
+ .permute(0, 2, 1, 3)
480
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
481
+ )
482
+ return self.to_out(out)
483
+
484
+
485
+ class VideoSpatialCrossAttention(CrossAttention):
486
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
487
+ super().__init__(query_dim, context_dim, heads, dim_head, dropout)
488
+
489
+ def forward(self, x, context=None, mask=None):
490
+ b, c, t, h, w = x.shape
491
+ if context is not None:
492
+ context = context.repeat(t, 1, 1)
493
+ x = super().forward(spatial_attn_reshape(x), context=context) + x
494
+ return spatial_attn_reshape_back(x, b, h)
495
+
496
+
497
+ # class BasicTransformerBlockST(nn.Module):
498
+ # def __init__(
499
+ # self,
500
+ # # Spatial Stuff
501
+ # dim,
502
+ # n_heads,
503
+ # d_head,
504
+ # dropout=0.0,
505
+ # context_dim=None,
506
+ # gated_ff=True,
507
+ # checkpoint=True,
508
+ # # Temporal Stuff
509
+ # temporal_length=None,
510
+ # image_length=None,
511
+ # use_relative_position=True,
512
+ # img_video_joint_train=False,
513
+ # cross_attn_on_tempoal=False,
514
+ # temporal_crossattn_type="selfattn",
515
+ # order="stst",
516
+ # temporalcrossfirst=False,
517
+ # temporal_context_dim=None,
518
+ # split_stcontext=False,
519
+ # local_spatial_temporal_attn=False,
520
+ # window_size=2,
521
+ # random_t=False,
522
+ # **kwargs,
523
+ # ):
524
+ # super().__init__()
525
+ # # Self attention
526
+ # self.attn1 = CrossAttention(
527
+ # query_dim=dim,
528
+ # heads=n_heads,
529
+ # dim_head=d_head,
530
+ # dropout=dropout,
531
+ # **kwargs,
532
+ # )
533
+ # self.attn2 = CrossAttention(
534
+ # query_dim=dim,
535
+ # context_dim=context_dim,
536
+ # heads=n_heads,
537
+ # dim_head=d_head,
538
+ # dropout=dropout,
539
+ # **kwargs,
540
+ # )
541
+ # if XFORMERS_IS_AVAILBLE:
542
+ # self.attn1.forward = self.attn1.efficient_forward
543
+ # self.attn2.forward = self.attn2.efficient_forward
544
+
545
+ # self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
546
+ # # cross attention if context is not None
547
+
548
+ # self.norm1 = nn.LayerNorm(dim)
549
+ # self.norm2 = nn.LayerNorm(dim)
550
+ # self.norm3 = nn.LayerNorm(dim)
551
+ # self.checkpoint = checkpoint
552
+ # self.order = order
553
+ # assert self.order in ["stst", "sstt", "st_parallel"]
554
+ # self.temporalcrossfirst = temporalcrossfirst
555
+ # self.split_stcontext = split_stcontext
556
+ # self.local_spatial_temporal_attn = local_spatial_temporal_attn
557
+ # if self.local_spatial_temporal_attn:
558
+ # assert self.order == "stst"
559
+ # assert self.order == "stst"
560
+ # self.window_size = window_size
561
+ # if not split_stcontext:
562
+ # temporal_context_dim = context_dim
563
+ # # Temporal attention
564
+ # assert temporal_crossattn_type in ["selfattn", "crossattn", "skip"]
565
+ # self.temporal_crossattn_type = temporal_crossattn_type
566
+ # self.attn1_tmp = TemporalCrossAttention(
567
+ # query_dim=dim,
568
+ # heads=n_heads,
569
+ # dim_head=d_head,
570
+ # dropout=dropout,
571
+ # temporal_length=temporal_length,
572
+ # image_length=image_length,
573
+ # use_relative_position=use_relative_position,
574
+ # img_video_joint_train=img_video_joint_train,
575
+ # **kwargs,
576
+ # )
577
+ # self.attn2_tmp = TemporalCrossAttention(
578
+ # query_dim=dim,
579
+ # heads=n_heads,
580
+ # dim_head=d_head,
581
+ # dropout=dropout,
582
+ # # cross attn
583
+ # context_dim=(
584
+ # temporal_context_dim if temporal_crossattn_type == "crossattn" else None
585
+ # ),
586
+ # # temporal attn
587
+ # temporal_length=temporal_length,
588
+ # image_length=image_length,
589
+ # use_relative_position=use_relative_position,
590
+ # img_video_joint_train=img_video_joint_train,
591
+ # **kwargs,
592
+ # )
593
+ # self.norm4 = nn.LayerNorm(dim)
594
+ # self.norm5 = nn.LayerNorm(dim)
595
+ # self.random_t = random_t
596
+ # # self.norm1_tmp = nn.LayerNorm(dim)
597
+ # # self.norm2_tmp = nn.LayerNorm(dim)
598
+
599
+ # ##############################################################################################################################################
600
+ # def forward(
601
+ # self,
602
+ # x,
603
+ # context=None,
604
+ # temporal_context=None,
605
+ # no_temporal_attn=None,
606
+ # attn_mask=None,
607
+ # **kwargs,
608
+ # ):
609
+ # # print(f'no_temporal_attn={no_temporal_attn}')
610
+
611
+ # if not self.split_stcontext:
612
+ # # st cross attention use the same context vector
613
+ # temporal_context = context.detach().clone()
614
+
615
+ # if context is None and temporal_context is None:
616
+ # # self-attention models
617
+ # if no_temporal_attn:
618
+ # raise NotImplementedError
619
+ # return checkpoint(
620
+ # self._forward_nocontext, (x), self.parameters(), self.checkpoint
621
+ # )
622
+ # else:
623
+ # # cross-attention models
624
+ # if no_temporal_attn:
625
+ # forward_func = self._forward_no_temporal_attn
626
+ # else:
627
+ # forward_func = self._forward
628
+ # inputs = (
629
+ # (x, context, temporal_context)
630
+ # if temporal_context is not None
631
+ # else (x, context)
632
+ # )
633
+ # return checkpoint(forward_func, inputs, self.parameters(), self.checkpoint)
634
+ # # if attn_mask is not None:
635
+ # # return checkpoint(self._forward, (x, context, temporal_context, attn_mask), self.parameters(), self.checkpoint)
636
+ # # return checkpoint(self._forward, (x, context, temporal_context), self.parameters(), self.checkpoint)
637
+
638
+ # def _forward(
639
+ # self,
640
+ # x,
641
+ # context=None,
642
+ # temporal_context=None,
643
+ # mask=None,
644
+ # no_temporal_attn=None,
645
+ # ):
646
+ # assert x.dim() == 5, f"x shape = {x.shape}"
647
+ # b, c, t, h, w = x.shape
648
+
649
+ # if self.order in ["stst", "sstt"]:
650
+ # x = self._st_cross_attn(
651
+ # x,
652
+ # context,
653
+ # temporal_context=temporal_context,
654
+ # order=self.order,
655
+ # mask=mask,
656
+ # ) # no_temporal_attn=no_temporal_attn,
657
+ # elif self.order == "st_parallel":
658
+ # x = self._st_cross_attn_parallel(
659
+ # x,
660
+ # context,
661
+ # temporal_context=temporal_context,
662
+ # order=self.order,
663
+ # ) # no_temporal_attn=no_temporal_attn,
664
+ # else:
665
+ # raise NotImplementedError
666
+
667
+ # x = self.ff(self.norm3(x)) + x
668
+ # if (no_temporal_attn is None) or (not no_temporal_attn):
669
+ # x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
670
+ # elif no_temporal_attn:
671
+ # x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
672
+ # return x
673
+
674
+ # def _forward_no_temporal_attn(
675
+ # self,
676
+ # x,
677
+ # context=None,
678
+ # temporal_context=None,
679
+ # ):
680
+ # # temporary implementation :(
681
+ # # because checkpoint does not support non-tensor inputs currently.
682
+ # assert x.dim() == 5, f"x shape = {x.shape}"
683
+ # b, c, t, h, w = x.shape
684
+
685
+ # if self.order in ["stst", "sstt"]:
686
+ # # x = self._st_cross_attn(x, context, temporal_context=temporal_context, order=self.order, no_temporal_attn=True,)
687
+ # # mask = torch.zeros([1, t, t], device=x.device).bool() if context is None else torch.zeros([1, context.shape[1], t], device=x.device).bool()
688
+ # mask = torch.zeros([1, t, t], device=x.device).bool()
689
+ # x = self._st_cross_attn(
690
+ # x,
691
+ # context,
692
+ # temporal_context=temporal_context,
693
+ # order=self.order,
694
+ # mask=mask,
695
+ # )
696
+ # elif self.order == "st_parallel":
697
+ # x = self._st_cross_attn_parallel(
698
+ # x,
699
+ # context,
700
+ # temporal_context=temporal_context,
701
+ # order=self.order,
702
+ # no_temporal_attn=True,
703
+ # )
704
+ # else:
705
+ # raise NotImplementedError
706
+
707
+ # x = self.ff(self.norm3(x)) + x
708
+ # x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
709
+ # # x = rearrange(x, '(b t) (h w) c -> b c t h w', b=b,h=h,w=w) # 3d -> 5d
710
+ # return x
711
+
712
+ # def _forward_nocontext(self, x, no_temporal_attn=None):
713
+ # assert x.dim() == 5, f"x shape = {x.shape}"
714
+ # b, c, t, h, w = x.shape
715
+
716
+ # if self.order in ["stst", "sstt"]:
717
+ # x = self._st_cross_attn(
718
+ # x, order=self.order, no_temporal_attn=no_temporal_attn
719
+ # )
720
+ # elif self.order == "st_parallel":
721
+ # x = self._st_cross_attn_parallel(
722
+ # x, order=self.order, no_temporal_attn=no_temporal_attn
723
+ # )
724
+ # else:
725
+ # raise NotImplementedError
726
+
727
+ # x = self.ff(self.norm3(x)) + x
728
+ # x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
729
+
730
+ # return x
731
+
732
+ # ##############################################################################################################################################
733
+
734
+ # def _st_cross_attn(
735
+ # self, x, context=None, temporal_context=None, order="stst", mask=None
736
+ # ): # no_temporal_attn=None,
737
+ # b, c, t, h, w = x.shape
738
+ # # if context is not None:
739
+ # # print(f'[_st_cross_attn input] x={x.shape}, context={context.shape}')
740
+ # # else:
741
+ # # print(f'[_st_cross_attn input] x={x.shape}')
742
+
743
+ # if order == "stst":
744
+ # # spatial self attention
745
+ # x = rearrange(x, "b c t h w -> (b t) (h w) c")
746
+ # # print(f'before attn1,x={x.shape}')
747
+
748
+ # x = self.attn1(self.norm1(x)) + x
749
+ # x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
750
+
751
+ # # temporal self attention
752
+ # # if (no_temporal_attn is None) or (not no_temporal_attn):
753
+ # if self.local_spatial_temporal_attn:
754
+ # x = local_spatial_temporal_attn_reshape(x, window_size=self.window_size)
755
+ # else:
756
+ # x = rearrange(x, "b c t h w -> (b h w) t c")
757
+ # x = self.attn1_tmp(self.norm4(x), mask=mask) + x
758
+
759
+ # if self.local_spatial_temporal_attn:
760
+ # x = local_spatial_temporal_attn_reshape_back(
761
+ # x, window_size=self.window_size, b=b, h=h, w=w, t=t
762
+ # )
763
+ # else:
764
+ # x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
765
+
766
+ # # spatial cross attention
767
+ # x = rearrange(x, "b c t h w -> (b t) (h w) c")
768
+ # # print(f'before attn2, x={x.shape}')
769
+ # # if context is not None:
770
+ # # print(f'[before attn2] context={context.shape}')
771
+ # if context is not None:
772
+ # if self.random_t:
773
+ # context_ = []
774
+ # for i in range(context.shape[0]):
775
+ # context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
776
+ # context_ = torch.cat(context_, dim=0)
777
+ # else:
778
+ # if context.shape[0] == t: # img captions no_temporal_attn or
779
+ # context_ = context
780
+ # else:
781
+ # # repeat conditions with t times
782
+ # context_ = []
783
+ # for i in range(context.shape[0]):
784
+ # context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
785
+ # context_ = torch.cat(context_, dim=0)
786
+ # else:
787
+ # context_ = None
788
+
789
+ # # if context_ is not None:
790
+ # # print(f'[before attn2] x={x.shape}, context_={context_.shape}')
791
+ # # else:
792
+ # # print(f'[before attn2] x={x.shape}')
793
+
794
+ # x = self.attn2(self.norm2(x), context=context_) + x
795
+
796
+ # # temporal cross attention
797
+ # # if (no_temporal_attn is None) or (not no_temporal_attn):
798
+ # x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
799
+ # x = rearrange(x, "b c t h w -> (b h w) t c")
800
+ # if self.temporal_crossattn_type == "crossattn":
801
+ # # tmporal cross attention
802
+ # if temporal_context is not None:
803
+ # # print(f'STATTN context={context.shape}, temporal_context={temporal_context.shape}')
804
+ # temporal_context = torch.cat(
805
+ # [context, temporal_context], dim=1
806
+ # ) # blc
807
+ # # print(f'STATTN after concat temporal_context={temporal_context.shape}')
808
+ # temporal_context = temporal_context.repeat(h * w, 1, 1)
809
+ # # print(f'after repeat temporal_context={temporal_context.shape}')
810
+ # else:
811
+ # temporal_context = context[0:1, ...].repeat(h * w, 1, 1)
812
+ # # print(f'STATTN after concat x={x.shape}')
813
+ # x = (
814
+ # self.attn2_tmp(self.norm5(x), context=temporal_context, mask=mask)
815
+ # + x
816
+ # )
817
+ # elif self.temporal_crossattn_type == "selfattn":
818
+ # # temporal self attention
819
+ # x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
820
+ # elif self.temporal_crossattn_type == "skip":
821
+ # # no temporal cross and self attention
822
+ # pass
823
+ # else:
824
+ # raise NotImplementedError
825
+
826
+ # elif order == "sstt":
827
+ # # spatial self attention
828
+ # x = rearrange(x, "b c t h w -> (b t) (h w) c")
829
+ # x = self.attn1(self.norm1(x)) + x
830
+
831
+ # # spatial cross attention
832
+ # context_ = context.repeat(t, 1, 1) if context is not None else None
833
+ # x = self.attn2(self.norm2(x), context=context_) + x
834
+ # x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
835
+
836
+ # if (no_temporal_attn is None) or (not no_temporal_attn):
837
+ # if self.temporalcrossfirst:
838
+ # # temporal cross attention
839
+ # if self.temporal_crossattn_type == "crossattn":
840
+ # # if temporal_context is not None:
841
+ # temporal_context = context.repeat(h * w, 1, 1)
842
+ # x = (
843
+ # self.attn2_tmp(
844
+ # self.norm5(x), context=temporal_context, mask=mask
845
+ # )
846
+ # + x
847
+ # )
848
+ # elif self.temporal_crossattn_type == "selfattn":
849
+ # x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
850
+ # elif self.temporal_crossattn_type == "skip":
851
+ # pass
852
+ # else:
853
+ # raise NotImplementedError
854
+ # # temporal self attention
855
+ # x = rearrange(x, "b c t h w -> (b h w) t c")
856
+ # x = self.attn1_tmp(self.norm4(x), mask=mask) + x
857
+ # else:
858
+ # # temporal self attention
859
+ # x = rearrange(x, "b c t h w -> (b h w) t c")
860
+ # x = self.attn1_tmp(self.norm4(x), mask=mask) + x
861
+ # # temporal cross attention
862
+ # if self.temporal_crossattn_type == "crossattn":
863
+ # if temporal_context is not None:
864
+ # temporal_context = context.repeat(h * w, 1, 1)
865
+ # x = (
866
+ # self.attn2_tmp(
867
+ # self.norm5(x), context=temporal_context, mask=mask
868
+ # )
869
+ # + x
870
+ # )
871
+ # elif self.temporal_crossattn_type == "selfattn":
872
+ # x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
873
+ # elif self.temporal_crossattn_type == "skip":
874
+ # pass
875
+ # else:
876
+ # raise NotImplementedError
877
+ # else:
878
+ # raise NotImplementedError
879
+
880
+ # return x
881
+
882
+ # def _st_cross_attn_parallel(
883
+ # self, x, context=None, temporal_context=None, order="sst", no_temporal_attn=None
884
+ # ):
885
+ # """order: x -> Self Attn -> Cross Attn -> attn_s
886
+ # x -> Temp Self Attn -> attn_t
887
+ # x' = x + attn_s + attn_t
888
+ # """
889
+ # if no_temporal_attn is not None:
890
+ # raise NotImplementedError
891
+
892
+ # B, C, T, H, W = x.shape
893
+ # # spatial self attention
894
+ # h = x
895
+ # h = rearrange(h, "b c t h w -> (b t) (h w) c")
896
+ # h = self.attn1(self.norm1(h)) + h
897
+ # # spatial cross
898
+ # # context_ = context.repeat(T, 1, 1) if context is not None else None
899
+ # if context is not None:
900
+ # context_ = []
901
+ # for i in range(context.shape[0]):
902
+ # context_.append(context[i].unsqueeze(0).repeat(T, 1, 1))
903
+ # context_ = torch.cat(context_, dim=0)
904
+ # else:
905
+ # context_ = None
906
+
907
+ # h = self.attn2(self.norm2(h), context=context_) + h
908
+ # h = rearrange(h, "(b t) (h w) c -> b c t h w", b=B, h=H)
909
+
910
+ # # temporal self
911
+ # h2 = x
912
+ # h2 = rearrange(h2, "b c t h w -> (b h w) t c")
913
+ # h2 = self.attn1_tmp(self.norm4(h2)) # + h2
914
+ # h2 = rearrange(h2, "(b h w) t c -> b c t h w", b=B, h=H, w=W)
915
+ # out = h + h2
916
+ # return rearrange(out, "b c t h w -> (b h w) t c")
917
+
918
+ ##############################################################################################################################################
919
+
920
+
921
+ def spatial_attn_reshape(x):
922
+ return rearrange(x, "b c t h w -> (b t) (h w) c")
923
+
924
+
925
+ def spatial_attn_reshape_back(x, b, h):
926
+ return rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
927
+
928
+
929
+ def temporal_attn_reshape(x):
930
+ return rearrange(x, "b c t h w -> (b h w) t c")
931
+
932
+
933
+ def temporal_attn_reshape_back(x, b, h, w):
934
+ return rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)
935
+
936
+
937
+ def local_spatial_temporal_attn_reshape(x, window_size):
938
+ B, C, T, H, W = x.shape
939
+ NH = H // window_size
940
+ NW = W // window_size
941
+ # x = x.view(B, C, T, NH, window_size, NW, window_size)
942
+ # tokens = x.permute(0, 1, 2, 3, 5, 4, 6).contiguous()
943
+ # tokens = tokens.view(-1, window_size, window_size, C)
944
+ x = rearrange(
945
+ x,
946
+ "b c t (nh wh) (nw ww) -> b c t nh wh nw ww",
947
+ nh=NH,
948
+ nw=NW,
949
+ wh=window_size,
950
+ ww=window_size,
951
+ ).contiguous() # # B, C, T, NH, NW, window_size, window_size
952
+ x = rearrange(
953
+ x, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c"
954
+ ) # (B, NH, NW) (T, window_size, window_size) C
955
+ return x
956
+
957
+
958
+ def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t):
959
+ B, L, C = x.shape
960
+ NH = h // window_size
961
+ NW = w // window_size
962
+ x = rearrange(
963
+ x,
964
+ "(b nh nw) (t wh ww) c -> b c t nh wh nw ww",
965
+ b=b,
966
+ nh=NH,
967
+ nw=NW,
968
+ t=t,
969
+ wh=window_size,
970
+ ww=window_size,
971
+ )
972
+ x = rearrange(x, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)")
973
+ return x
974
+
975
+
976
+ class SpatialTemporalTransformer(nn.Module):
977
+ """
978
+ Transformer block for video-like data (5D tensor).
979
+ First, project the input (aka embedding) with NO reshape.
980
+ Then apply standard transformer action.
981
+ The 5D -> 3D reshape operation will be done in the specific attention module.
982
+ """
983
+
984
+ def __init__(
985
+ self,
986
+ in_channels,
987
+ n_heads,
988
+ d_head,
989
+ depth=1,
990
+ dropout=0.0,
991
+ context_dim=None,
992
+ # Temporal stuff
993
+ temporal_length=None,
994
+ image_length=None,
995
+ use_relative_position=True,
996
+ img_video_joint_train=False,
997
+ cross_attn_on_tempoal=False,
998
+ temporal_crossattn_type="selfattn",
999
+ order="stst",
1000
+ temporalcrossfirst=False,
1001
+ split_stcontext=False,
1002
+ temporal_context_dim=None,
1003
+ **kwargs,
1004
+ ):
1005
+ super().__init__()
1006
+
1007
+ self.in_channels = in_channels
1008
+ inner_dim = n_heads * d_head
1009
+
1010
+ self.norm = Normalize(in_channels)
1011
+ self.proj_in = nn.Conv3d(
1012
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
1013
+ )
1014
+
1015
+ self.transformer_blocks = nn.ModuleList(
1016
+ [
1017
+ BasicTransformerBlockST(
1018
+ inner_dim,
1019
+ n_heads,
1020
+ d_head,
1021
+ dropout=dropout,
1022
+ # cross attn
1023
+ context_dim=context_dim,
1024
+ # temporal attn
1025
+ temporal_length=temporal_length,
1026
+ image_length=image_length,
1027
+ use_relative_position=use_relative_position,
1028
+ img_video_joint_train=img_video_joint_train,
1029
+ temporal_crossattn_type=temporal_crossattn_type,
1030
+ order=order,
1031
+ temporalcrossfirst=temporalcrossfirst,
1032
+ split_stcontext=split_stcontext,
1033
+ temporal_context_dim=temporal_context_dim,
1034
+ **kwargs,
1035
+ )
1036
+ for d in range(depth)
1037
+ ]
1038
+ )
1039
+
1040
+ self.proj_out = zero_module(
1041
+ nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
1042
+ )
1043
+
1044
+ def forward(self, x, context=None, temporal_context=None, **kwargs):
1045
+ # note: if no context is given, cross-attention defaults to self-attention
1046
+ assert x.dim() == 5, f"x shape = {x.shape}"
1047
+ b, c, t, h, w = x.shape
1048
+ x_in = x
1049
+
1050
+ x = self.norm(x)
1051
+ x = self.proj_in(x)
1052
+
1053
+ for block in self.transformer_blocks:
1054
+ x = block(x, context=context, temporal_context=temporal_context, **kwargs)
1055
+
1056
+ x = self.proj_out(x)
1057
+ return x + x_in
1058
+
1059
+
1060
+ # ---------------------------------------------------------------------------------------------------
1061
+
1062
+
1063
+ class STAttentionBlock2(nn.Module):
1064
+ def __init__(
1065
+ self,
1066
+ channels,
1067
+ num_heads=1,
1068
+ num_head_channels=-1,
1069
+ use_checkpoint=False, # not used, only used in ResBlock
1070
+ use_new_attention_order=False, # QKVAttention or QKVAttentionLegacy
1071
+ temporal_length=16, # used in relative positional representation.
1072
+ image_length=8, # used for image-video joint training.
1073
+ use_relative_position=False, # whether to use relative positional representation in temporal attention.
1074
+ img_video_joint_train=False,
1075
+ # norm_type="groupnorm",
1076
+ attn_norm_type="group",
1077
+ use_tempoal_causal_attn=False,
1078
+ ):
1079
+ """
1080
+ version 1: guided_diffusion implemented version
1081
+ version 2: remove args input argument
1082
+ """
1083
+ super().__init__()
1084
+
1085
+ if num_head_channels == -1:
1086
+ self.num_heads = num_heads
1087
+ else:
1088
+ assert (
1089
+ channels % num_head_channels == 0
1090
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
1091
+ self.num_heads = channels // num_head_channels
1092
+ self.use_checkpoint = use_checkpoint
1093
+
1094
+ self.temporal_length = temporal_length
1095
+ self.image_length = image_length
1096
+ self.use_relative_position = use_relative_position
1097
+ self.img_video_joint_train = img_video_joint_train
1098
+ self.attn_norm_type = attn_norm_type
1099
+ assert self.attn_norm_type in ["group", "no_norm"]
1100
+ self.use_tempoal_causal_attn = use_tempoal_causal_attn
1101
+
1102
+ if self.attn_norm_type == "group":
1103
+ self.norm_s = normalization(channels)
1104
+ self.norm_t = normalization(channels)
1105
+
1106
+ self.qkv_s = conv_nd(1, channels, channels * 3, 1)
1107
+ self.qkv_t = conv_nd(1, channels, channels * 3, 1)
1108
+
1109
+ if self.img_video_joint_train:
1110
+ mask = th.ones(
1111
+ [1, temporal_length + image_length, temporal_length + image_length]
1112
+ )
1113
+ mask[:, temporal_length:, :] = 0
1114
+ mask[:, :, temporal_length:] = 0
1115
+ self.register_buffer("mask", mask)
1116
+ else:
1117
+ self.mask = None
1118
+
1119
+ if use_new_attention_order:
1120
+ # split qkv before split heads
1121
+ self.attention_s = QKVAttention(self.num_heads)
1122
+ self.attention_t = QKVAttention(self.num_heads)
1123
+ else:
1124
+ # split heads before split qkv
1125
+ self.attention_s = QKVAttentionLegacy(self.num_heads)
1126
+ self.attention_t = QKVAttentionLegacy(self.num_heads)
1127
+
1128
+ if use_relative_position:
1129
+ self.relative_position_k = RelativePosition(
1130
+ num_units=channels // self.num_heads,
1131
+ max_relative_position=temporal_length,
1132
+ )
1133
+ self.relative_position_v = RelativePosition(
1134
+ num_units=channels // self.num_heads,
1135
+ max_relative_position=temporal_length,
1136
+ )
1137
+
1138
+ self.proj_out_s = zero_module(
1139
+ conv_nd(1, channels, channels, 1)
1140
+ ) # conv_dim, in_channels, out_channels, kernel_size
1141
+ self.proj_out_t = zero_module(
1142
+ conv_nd(1, channels, channels, 1)
1143
+ ) # conv_dim, in_channels, out_channels, kernel_size
1144
+
1145
+ def forward(self, x, mask=None):
1146
+ b, c, t, h, w = x.shape
1147
+
1148
+ # spatial
1149
+ out = rearrange(x, "b c t h w -> (b t) c (h w)")
1150
+ if self.attn_norm_type == "no_norm":
1151
+ qkv = self.qkv_s(out)
1152
+ else:
1153
+ qkv = self.qkv_s(self.norm_s(out))
1154
+ out = self.attention_s(qkv)
1155
+ out = self.proj_out_s(out)
1156
+ out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
1157
+ x += out
1158
+
1159
+ # temporal
1160
+ out = rearrange(x, "b c t h w -> (b h w) c t")
1161
+ if self.attn_norm_type == "no_norm":
1162
+ qkv = self.qkv_t(out)
1163
+ else:
1164
+ qkv = self.qkv_t(self.norm_t(out))
1165
+
1166
+ # relative positional embedding
1167
+ if self.use_relative_position:
1168
+ len_q = qkv.size()[-1]
1169
+ len_k, len_v = len_q, len_q
1170
+ k_rp = self.relative_position_k(len_q, len_k)
1171
+ v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
1172
+ out = self.attention_t(
1173
+ qkv,
1174
+ rp=(k_rp, v_rp),
1175
+ mask=self.mask,
1176
+ use_tempoal_causal_attn=self.use_tempoal_causal_attn,
1177
+ )
1178
+ else:
1179
+ out = self.attention_t(
1180
+ qkv,
1181
+ rp=None,
1182
+ mask=self.mask,
1183
+ use_tempoal_causal_attn=self.use_tempoal_causal_attn,
1184
+ )
1185
+
1186
+ out = self.proj_out_t(out)
1187
+ out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
1188
+
1189
+ return x + out
1190
+
1191
+
1192
+ # ---------------------------------------------------------------------------------------------------------------
1193
+
1194
+
1195
+ class QKVAttentionLegacy(nn.Module):
1196
+ """
1197
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
1198
+ """
1199
+
1200
+ def __init__(self, n_heads):
1201
+ super().__init__()
1202
+ self.n_heads = n_heads
1203
+
1204
+ def forward(self, qkv, rp=None, mask=None):
1205
+ """
1206
+ Apply QKV attention.
1207
+
1208
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
1209
+ :return: an [N x (H * C) x T] tensor after attention.
1210
+ """
1211
+ if rp is not None or mask is not None:
1212
+ raise NotImplementedError
1213
+ bs, width, length = qkv.shape
1214
+ assert width % (3 * self.n_heads) == 0
1215
+ ch = width // (3 * self.n_heads)
1216
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
1217
+ scale = 1 / math.sqrt(math.sqrt(ch))
1218
+ weight = th.einsum(
1219
+ "bct,bcs->bts", q * scale, k * scale
1220
+ ) # More stable with f16 than dividing afterwards
1221
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
1222
+ a = th.einsum("bts,bcs->bct", weight, v)
1223
+ return a.reshape(bs, -1, length)
1224
+
1225
+ @staticmethod
1226
+ def count_flops(model, _x, y):
1227
+ return count_flops_attn(model, _x, y)
1228
+
1229
+
1230
+ # ---------------------------------------------------------------------------------------------------------------
1231
+
1232
+
1233
+ class QKVAttention(nn.Module):
1234
+ """
1235
+ A module which performs QKV attention and splits in a different order.
1236
+ """
1237
+
1238
+ def __init__(self, n_heads):
1239
+ super().__init__()
1240
+ self.n_heads = n_heads
1241
+
1242
+ def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
1243
+ """
1244
+ Apply QKV attention.
1245
+
1246
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
1247
+ :return: an [N x (H * C) x T] tensor after attention.
1248
+ """
1249
+ bs, width, length = qkv.shape
1250
+ assert width % (3 * self.n_heads) == 0
1251
+ ch = width // (3 * self.n_heads)
1252
+ # print('qkv', qkv.size())
1253
+ qkv = qkv.contiguous()
1254
+ q, k, v = qkv.chunk(3, dim=1)
1255
+ scale = 1 / math.sqrt(math.sqrt(ch))
1256
+ # print('bs, self.n_heads, ch, length', bs, self.n_heads, ch, length)
1257
+
1258
+ weight = th.einsum(
1259
+ "bct,bcs->bts",
1260
+ (q * scale).view(bs * self.n_heads, ch, length),
1261
+ (k * scale).view(bs * self.n_heads, ch, length),
1262
+ ) # More stable with f16 than dividing afterwards
1263
+ # weight:[b,t,s] b=bs*n_heads*T
1264
+
1265
+ if rp is not None:
1266
+ k_rp, v_rp = rp # [length, length, head_dim] [8, 8, 48]
1267
+ weight2 = th.einsum(
1268
+ "bct,tsc->bst", (q * scale).view(bs * self.n_heads, ch, length), k_rp
1269
+ )
1270
+ weight += weight2
1271
+
1272
+ if use_tempoal_causal_attn:
1273
+ # weight = torch.tril(weight)
1274
+ assert mask is None, f"Not implemented for merging two masks!"
1275
+ mask = torch.tril(torch.ones(weight.shape, device=weight.device))
1276
+ else:
1277
+ if mask is not None: # only keep upper-left matrix
1278
+ # process mask
1279
+ c, t, _ = weight.shape
1280
+
1281
+ if mask.shape[-1] > t:
1282
+ mask = mask[:, :t, :t]
1283
+ elif mask.shape[-1] < t: # pad ones
1284
+ mask_ = th.zeros([c, t, t]).to(mask.device)
1285
+ t_ = mask.shape[-1]
1286
+ mask_[:, :t_, :t_] = mask
1287
+ mask = mask_
1288
+ else:
1289
+ assert (
1290
+ weight.shape[-1] == mask.shape[-1]
1291
+ ), f"weight={weight.shape}, mask={mask.shape}"
1292
+
1293
+ if mask is not None:
1294
+ INF = -1e8 # float('-inf')
1295
+ weight = weight.float().masked_fill(mask == 0, INF)
1296
+
1297
+ weight = F.softmax(weight.float(), dim=-1).type(
1298
+ weight.dtype
1299
+ ) # [256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
1300
+ # weight = F.softmax(weight, dim=-1)#[256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
1301
+ a = th.einsum(
1302
+ "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
1303
+ ) # [256, 48, 8] [b, head_dim, t]
1304
+
1305
+ if rp is not None:
1306
+ a2 = th.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2) # btc->bct
1307
+ a += a2
1308
+
1309
+ return a.reshape(bs, -1, length)
1310
+
1311
+
1312
+ # ---------------------------------------------------------------------------------------------------------------
1313
+
1314
+ # ---------------------------------------------------------------------------------------------------------------
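A minimal usage sketch for the TemporalCrossAttention module defined above. The import path is hypothetical (the file uses relative imports such as from .utils_encoder import ..., so it has to be imported through its package), and the shapes follow the rearrange patterns in forward(): tokens of shape (batch*height*width, frames, channels). SpatialTemporalTransformer is not exercised here because it references BasicTransformerBlockST, whose definition is commented out in this file. Since the q/k/v and output projections are zero-initialized, a freshly constructed module returns zeros.

import torch

# hypothetical import path; adjust to the actual package layout
from llava.model.multimodal_encoder.attention_temporal_videoae import TemporalCrossAttention

temporal_length = 8
attn = TemporalCrossAttention(
    query_dim=320,
    heads=8,
    dim_head=40,
    temporal_length=temporal_length,
    use_relative_position=True,      # adds the RelativePosition bias to logits and values
    use_tempoal_causal_attn=True,    # lower-triangular mask over the frame axis
)

x = torch.randn(2 * 16 * 16, temporal_length, 320)   # (b*h*w, t, c)
out = attn(x)                                        # (b*h*w, t, c); all zeros at init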
base_encoder.py ADDED
@@ -0,0 +1,68 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class BaseVisionTower(nn.Module):
8
+ def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower_name
14
+ self.delay_load = delay_load
15
+
16
+ @abstractmethod
17
+ def load_model(self, device_map=None):
18
+ raise NotImplementedError("Subclasses must implement load_model")
19
+
20
+ @abstractmethod
21
+ def _forward(self, images):
22
+ raise NotImplementedError("Subclasses must implement forward")
23
+
24
+ def forward(self, images):
25
+ if type(images) is list:
26
+ image_features = [self._forward(image.unsqueeze(0)) for image in images]
27
+ else:
28
+ image_features = self._forward(images)
29
+
30
+ return image_features
31
+
32
+ @property
33
+ def dummy_feature(self):
34
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
35
+
36
+ @property
37
+ def dtype(self):
38
+ # Dynamically infer the dtype from the first parameter, if not explicitly specified
39
+ if hasattr(self.vision_tower, "dtype"):
40
+ return self.vision_tower.dtype
41
+ else:
42
+ params = list(self.vision_tower.parameters())
43
+ return (
44
+ params[0].dtype if len(params) > 0 else torch.float32
45
+ ) # Default to torch.float32 if no parameters
46
+
47
+ @property
48
+ def device(self):
49
+ # Dynamically infer the device from the first parameter, if not explicitly specified
50
+ if hasattr(self.vision_tower, "device"):
51
+ return self.vision_tower.device
52
+ else:
53
+ params = list(self.vision_tower.parameters())
54
+ return (
55
+ params[0].device if len(params) > 0 else torch.device("cpu")
56
+ ) # Default to CPU if no parameters
57
+ @property
58
+ def config(self):
59
+ if self.is_loaded:
60
+ return self.vision_tower.config
61
+ else:
62
+ return self.cfg_only
63
+ @property
64
+ def hidden_size(self):
65
+ try:
66
+ return self.config.hidden_size
67
+ except:
68
+ return self._hidden_size
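BaseVisionTower above leaves load_model and _forward abstract and expects subclasses to populate self.vision_tower, which the dtype, device, config and hidden_size properties read from. A minimal sketch of that contract with a stand-in encoder; DummyVisionTower and its conv layer are hypothetical, whereas a real subclass such as SigLipVisionTower (siglip_encoder.py in this commit) loads a pretrained checkpoint instead.

import torch
import torch.nn as nn
# assumes BaseVisionTower (base_encoder.py above) is importable from its package

class DummyVisionTower(BaseVisionTower):
    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
        super().__init__(vision_tower_name, vision_tower_cfg, delay_load)
        self._hidden_size = 64           # fallback returned by the hidden_size property
        if not delay_load:
            self.load_model()

    def load_model(self, device_map=None):
        # stand-in for loading a real encoder checkpoint
        self.vision_tower = nn.Conv2d(3, self._hidden_size, kernel_size=16, stride=16)
        self.is_loaded = True

    def _forward(self, images):
        # (B, 3, H, W) -> (B, num_patches, hidden_size)
        return self.vision_tower(images).flatten(2).transpose(1, 2)

tower = DummyVisionTower("dummy", vision_tower_cfg=None)
feats = tower(torch.randn(2, 3, 64, 64))   # forward() dispatches to _forward -> (2, 16, 64)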
builder.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+ from .siglip_encoder import SigLipVisionTower
3
+
4
+
5
+ def build_vision_tower(vision_tower_cfg, **kwargs):
6
+
7
+ vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
8
+ is_absolute_path_exists = os.path.exists(vision_tower)
9
+ use_s2 = getattr(vision_tower_cfg, "s2", False)
10
+
11
+ #print(getattr(vision_tower_cfg, "vision_tower", None))
12
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
13
+ if getattr(vision_tower_cfg, "vision_tower", None) and "siglip" in getattr(vision_tower_cfg, "vision_tower", None).lower():
14
+ #print('*************\n')
15
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
16
+
17
+ raise ValueError(f"Unknown vision tower: {vision_tower}")
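As committed, build_vision_tower always returns a SigLipVisionTower: the unconditional return short-circuits the function, so the siglip name check and the ValueError below it are unreachable, and is_absolute_path_exists / use_s2 are computed but unused. A usage sketch with a hypothetical config object; only the mm_vision_tower / vision_tower attribute is actually read, and the checkpoint name is an assumption.

from types import SimpleNamespace

cfg = SimpleNamespace(
    mm_vision_tower="google/siglip-so400m-patch14-384",   # assumed checkpoint name
    s2=False,
)
vision_tower = build_vision_tower(cfg)   # SigLipVisionTower from siglip_encoder.py (added in this commit)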
llava_arch.py CHANGED
@@ -14,25 +14,48 @@
14
 
15
 
16
  from abc import ABC, abstractmethod
17
-
 
18
  import math
19
  import re
20
  import time
21
  import torch
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
- from .multimodal_encoder.builder import build_vision_tower
25
- from .multimodal_resampler.builder import build_vision_resampler
26
- from .multimodal_projector.builder import build_vision_projector
 
 
 
 
27
  from transformers import AutoTokenizer
28
 
29
- from longva.longva.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
30
 
31
- from longva.longva.mm_utils import get_anyres_image_grid_shape
32
- from longva.longva.utils import rank0_print
33
  import random
34
  from .sae import SiglipAE
35
- from .WindowTimeToTokenAttention import WindowTimeToTokenAttention
36
  import numpy as np
37
  import torch.nn.functional as F
38
  import pdb
@@ -281,15 +304,13 @@ class LlavaMetaForCausalLM(ABC):
281
  return expanded_x
282
 
283
  def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
284
- #################################################################################
285
- # if videos_or_images.shape[0] > 360:
286
- # random_indices = np.random.choice(videos_or_images.shape[0], size=360, replace=False)
287
- # videos_or_images = videos_or_images[random_indices]
288
- # split_sizes=videos_or_images.shape[0]
289
-
290
- #################################################################################
291
  # Define the maximum batch size (1024 frames)
292
- max_batch_size = 60
293
  num_frames = videos_or_images.shape[0]
294
  # Initialize a list to store the features from each batch
295
  videos_or_images_features = []
@@ -312,47 +333,49 @@ class LlavaMetaForCausalLM(ABC):
312
  else:
313
  videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
314
 
315
- per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096)
316
  all_videos_or_images_features = []
317
-
318
-
 
 
 
 
 
 
 
319
  for idx, feat in enumerate(per_videos_or_images_features):
320
- #print(feat.shape,end='1\n')
321
- feat=self.interpolate(feat)
322
- #######################################################
323
- if idx in video_idx_in_batch:
324
- feat=self.add_video(feat)
325
- else:
326
- feat=self.add_image(feat)
 
 
 
327
 
328
- bc,ch,h,w=feat.shape
329
-
330
- feat = feat.view(bc//4,ch,4,h,w)
331
- if bc//4>24:
332
- chunk_size = 24
333
- chunks = torch.split(feat, chunk_size, dim=0)
334
- interpolated_chunks = []
335
- for chunk in chunks:
336
- interpolated_chunk=self.get_model().sae(chunk).squeeze(2)
337
- interpolated_chunks.append(interpolated_chunk)
338
- feat = torch.cat(interpolated_chunks, dim=0)
339
- del interpolated_chunks
340
- del chunks
341
- else:
342
- feat=self.get_model().sae(feat).squeeze(2)
343
- feat = feat.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
344
- #print(feat.shape,end='3\n')
345
- feat = self.get_model().mm_projector(feat)
346
- #print(feat.shape,end='4\n')
347
- # Post pooling
348
- if idx in video_idx_in_batch:
349
- #print('************************',idx,video_idx_in_batch)
350
- feat = self.get_2dPool(feat)
351
- all_videos_or_images_features.append(feat)
352
-
353
  del per_videos_or_images_features
 
 
 
 
354
  return all_videos_or_images_features
355
- ########################################################
 
356
  def interpolate(self,image_features):
357
  b, num_tokens, dim = image_features.shape
358
 
@@ -383,6 +406,7 @@ class LlavaMetaForCausalLM(ABC):
383
  return image_features
384
 
385
  def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None,time_embedding=None):
 
386
  vision_tower = self.get_vision_tower()
387
  if vision_tower is None or images is None or input_ids.shape[1] == 1:
388
  return input_ids, position_ids, attention_mask, past_key_values, None, labels
 
14
 
15
 
16
  from abc import ABC, abstractmethod
17
+ import importlib.util
18
+ import os.path as osp
19
  import math
20
  import re
21
  import time
22
  import torch
23
  import torch.nn as nn
24
  import torch.nn.functional as F
25
+
26
+ try:
27
+ from .builder import build_vision_tower
28
+ from .builder import build_vision_resampler
29
+ from .builder import build_vision_projector
30
+ except ModuleNotFoundError:
31
+ spec = importlib.util.spec_from_file_location(
32
+ "builder",
33
+ osp.join(osp.dirname(__file__), "builder.py"),
34
+ )
35
+ builder = importlib.util.module_from_spec(spec)
36
+ spec.loader.exec_module(builder)
37
+ build_vision_tower = getattr(
38
+ builder,
39
+ "build_vision_tower",
40
+ )
41
+ build_vision_resampler = getattr(
42
+ builder,
43
+ "build_vision_resampler",
44
+ )
45
+ build_vision_projector = getattr(
46
+ builder,
47
+ "build_vision_projector",
48
+ )
49
+
50
+
51
  from transformers import AutoTokenizer
52
 
53
+ from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
54
 
55
+ from .mm_utils import get_anyres_image_grid_shape
56
+ from .utils import rank0_print
57
  import random
58
  from .sae import SiglipAE
 
59
  import numpy as np
60
  import torch.nn.functional as F
61
  import pdb
 
304
  return expanded_x
305
 
306
  def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
307
+ pdb.set_trace()
308
+ if self.config.enable_chunk_prefill:
309
+ chunk_size_for_vision_tower = self.config.prefill_config['chunk_size_for_vision_tower']
310
+ else:
311
+ chunk_size_for_vision_tower = 100000
 
 
312
  # Define the maximum number of frames processed by the vision tower in one chunk
313
+ max_batch_size = chunk_size_for_vision_tower
314
  num_frames = videos_or_images.shape[0]
315
  # Initialize a list to store the features from each batch
316
  videos_or_images_features = []
 
333
  else:
334
  videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
335
 
336
+ per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0)
337
  all_videos_or_images_features = []
338
+
339
+ peak_memory_allocated = torch.cuda.max_memory_allocated()
340
+ print(f"vision encoder 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
341
+
342
+ del videos_or_images_features
343
+ torch.cuda.empty_cache()
344
+
345
+ chunk_size = chunk_size_for_vision_tower
346
+ all_feat_list = []
347
  for idx, feat in enumerate(per_videos_or_images_features):
348
+ for i in range(0, feat.shape[0], chunk_size):
349
+ batched_feat = feat[i:i+chunk_size]
350
+ batched_feat=self.interpolate(batched_feat) # torch.Size([187, 1152, 24, 24])
351
+ if idx in video_idx_in_batch:
352
+ batched_feat = self.add_video(batched_feat) # torch.Size([188, 1152, 24, 24])
353
+ else:
354
+ batched_feat = self.add_image(batched_feat)
355
+
356
+ bc,ch,h,w = batched_feat.shape
357
+ batched_feat = batched_feat.view(bc//4,ch,4,h,w)
358
+
359
+ batched_feat=self.get_model().sae(batched_feat).squeeze(2)
360
+ batched_feat = batched_feat.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
361
+ batched_feat = self.get_model().mm_projector(batched_feat)
362
+
363
+
364
+ batched_feat = self.get_2dPool(batched_feat)
365
+ all_feat_list.append(batched_feat)
366
+
367
+ feat = torch.cat(all_feat_list, dim=0)
368
+ peak_memory_allocated = torch.cuda.max_memory_allocated()
369
+ print(f"sae 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
370
371
  del per_videos_or_images_features
372
+ del all_feat_list
373
+ torch.cuda.empty_cache()
374
+
375
+ all_videos_or_images_features.append(feat)
376
  return all_videos_or_images_features
377
+
378
+
379
  def interpolate(self,image_features):
380
  b, num_tokens, dim = image_features.shape
381
 
 
406
  return image_features
407
 
408
  def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None,time_embedding=None):
409
+ pdb.set_trace()
410
  vision_tower = self.get_vision_tower()
411
  if vision_tower is None or images is None or input_ids.shape[1] == 1:
412
  return input_ids, position_ids, attention_mask, past_key_values, None, labels
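Note: the chunked encode_multimodals path above bounds vision-tower peak memory by encoding at most chunk_size_for_vision_tower frames per call. A minimal sketch of that pattern (hypothetical helper name; vision_tower stands in for self.get_model().get_vision_tower(), chunk_size for config.prefill_config['chunk_size_for_vision_tower']):

import torch

def encode_in_chunks(frames, vision_tower, chunk_size):
    # Encode `chunk_size` frames at a time and concatenate the features,
    # so peak activation memory is independent of the clip length.
    outputs = []
    for start in range(0, frames.shape[0], chunk_size):
        outputs.append(vision_tower(frames[start:start + chunk_size]))
    return torch.cat(outputs, dim=0)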
llava_qwen.py CHANGED
@@ -21,7 +21,7 @@ import transformers
21
  from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
22
  from transformers.modeling_outputs import CausalLMOutputWithPast
23
  from transformers.generation.utils import GenerateOutput
24
- from longva.longva.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
25
  from .modeling_qwen2 import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
26
  import pdb
27
  import time
@@ -211,6 +211,7 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
211
  time_token_end_indices=None,
212
  block_size_chosed=None,
213
  prev_blocks_num=None,
 
214
  ) -> Union[Tuple, CausalLMOutputWithPast]:
215
 
216
  block_size = block_size_chosed
@@ -218,7 +219,6 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
218
  visual_token_end_pos = visual_token_end_pos
219
  visual_len = visual_token_end_pos - visual_token_start_pos
220
  num_blocks = (frames_num + block_size * 4 - 1) // (block_size * 4)
221
- # print(f'block_size: {block_size}, num_blocks: {num_blocks}')
222
 
223
  # streaming inps
224
  blocks_positions = [[(0, 0, visual_token_start_pos)]]
@@ -254,10 +254,10 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
254
  suffix_embeds = full_inputs_embeds[:, visual_token_end_pos:, :]
255
  num_visual_tokens = visual_embeds.size(1)
256
 
257
- all_past_key_values = [[] for _ in range(len(self.model.layers))] # assumes the model exposes a `layers` attribute
258
  prefix_past_key_values = []
259
 
260
- torch.cuda.reset_peak_memory_stats()
261
 
262
  if prefix_embeds.size(1) > 0:
263
  pkv = self.process_block(prefix_embeds, bsz=bsz, device=device)
@@ -288,16 +288,15 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
288
 
289
  block_streaming_past_key_values_part1 = prefix_past_key_values
290
  position_ids_part1 = torch.arange(0, prefix_past_key_values[0][0].size(2), dtype=torch.long, device=device)
291
- block_streaming_past_key_values_part2 = [[] for _ in range(len(self.model.layers))] # per-layer KV storage
292
  position_ids_part2 = torch.tensor([], dtype=torch.long, device=device)
293
  block_streaming_past_key_values_part3=None
294
  position_ids_part3 = None
295
 
296
  query_position_ids = None
297
  for idx, single_block in enumerate(blocks_positions[:]):
298
- if idx == 0:
299
- continue
300
- if idx <= prev_blocks_num:
301
  continue
302
 
303
  b_start, _, _ = single_block[0]
@@ -312,13 +311,15 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
312
  true_block_length = b_end - b_start
313
 
314
  block_streaming_past_key_values_part3 = [tmp[-prev_blocks_num:] for tmp in all_past_key_values]
315
- # block_streaming_past_key_values_part3 = [
316
- # [
317
- # (t[0].to(device=device), t[1].to(device=device))
318
- # for t in sublist
319
- # ]
320
- # for sublist in block_streaming_past_key_values_part3
321
- # ]
 
 
322
 
323
  block_streaming_past_key_values = self.cat_history_kvs(block_streaming_past_key_values_part1, block_streaming_past_key_values_part2, block_streaming_past_key_values_part3)
324
 
@@ -337,8 +338,11 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
337
  key_this_block, val_this_block = pkv[i]
338
  key_this_block = key_this_block[:,:,length_before_chunk:,:]
339
  val_this_block = val_this_block[:,:,length_before_chunk:,:]
340
- all_past_key_values[i].append( (key_this_block, val_this_block) )
341
- # all_past_key_values[i].append( (key_this_block.to('cpu'), val_this_block.to('cpu')) )
 
 
 
342
 
343
  time_keys_list = []
344
  time_vals_list = []
@@ -371,6 +375,9 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
371
  values = torch.cat([pkv[1].to(device=device) for pkv in layer_pkvs], dim=2)
372
  merged_pkv.append((keys, values))
373
 
 
 
 
374
 
375
  pkv = merged_pkv
376
  del block_streaming_past_key_values
@@ -383,6 +390,8 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
383
  # TODO: bi-decoding acceleration
384
  mixed_prefill_past_key_values = pkv
385
  prefill_len = visual_token_end_pos
 
 
386
 
387
  # Process suffix
388
  if suffix_embeds.size(1) > 0:
@@ -404,6 +413,8 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
404
  return_dict=return_dict,
405
  # blocks_positions=None,
406
  )
 
 
407
  del mixed_prefill_past_key_values
408
  torch.cuda.empty_cache()
409
 
@@ -508,12 +519,17 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
508
  )
509
 
510
  if inputs_embeds is None:
 
511
  (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes, time_embedding)
512
 
513
- if self.config.enable_sparse:
514
- block_size_chosed = self.config.sparse_config['block_size_chosed']
515
- prev_blocks_num = self.config.sparse_config['prev_blocks_num']
516
- if self.config.sparse_mode=='streaming':
 
 
 
 
517
  return self.forward_streaming(
518
  input_ids=input_ids,
519
  attention_mask=attention_mask,
@@ -533,10 +549,11 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
533
  frames_num=frames_num,
534
  time_token_indices=time_token_indices,
535
  time_token_end_indices=time_token_end_indices,
536
- block_size_chosed=block_size_chosed,
537
- prev_blocks_num=prev_blocks_num,
 
538
  )
539
- elif self.config.sparse_mode=='mask':
540
  return self.forward_mask(
541
  input_ids=input_ids,
542
  attention_mask=attention_mask,
@@ -584,6 +601,8 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
584
  **kwargs,
585
  ) -> Union[GenerateOutput, torch.LongTensor]:
586
 
 
 
587
  position_ids = kwargs.pop("position_ids", None)
588
  attention_mask = kwargs.pop("attention_mask", None)
589
 
@@ -631,6 +650,7 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
631
  sample_fps=1,
632
  max_sample_fps=4,
633
  generation_config={}):
 
634
 
635
  # prepare text input
636
  conv = conv_templates["qwen_1_5"].copy()
 
21
  from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
22
  from transformers.modeling_outputs import CausalLMOutputWithPast
23
  from transformers.generation.utils import GenerateOutput
24
+ from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
25
  from .modeling_qwen2 import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
26
  import pdb
27
  import time
 
211
  time_token_end_indices=None,
212
  block_size_chosed=None,
213
  prev_blocks_num=None,
214
+ offload: Optional[bool] = None,
215
  ) -> Union[Tuple, CausalLMOutputWithPast]:
216
 
217
  block_size = block_size_chosed
 
219
  visual_token_end_pos = visual_token_end_pos
220
  visual_len = visual_token_end_pos - visual_token_start_pos
221
  num_blocks = (frames_num + block_size * 4 - 1) // (block_size * 4)
 
222
 
223
  # streaming inps
224
  blocks_positions = [[(0, 0, visual_token_start_pos)]]
 
254
  suffix_embeds = full_inputs_embeds[:, visual_token_end_pos:, :]
255
  num_visual_tokens = visual_embeds.size(1)
256
 
257
+ all_past_key_values = [[] for _ in range(len(self.model.layers))]
258
  prefix_past_key_values = []
259
 
260
+ # torch.cuda.reset_peak_memory_stats()
261
 
262
  if prefix_embeds.size(1) > 0:
263
  pkv = self.process_block(prefix_embeds, bsz=bsz, device=device)
 
288
 
289
  block_streaming_past_key_values_part1 = prefix_past_key_values
290
  position_ids_part1 = torch.arange(0, prefix_past_key_values[0][0].size(2), dtype=torch.long, device=device)
291
+ block_streaming_past_key_values_part2 = [[] for _ in range(len(self.model.layers))]
292
  position_ids_part2 = torch.tensor([], dtype=torch.long, device=device)
293
  block_streaming_past_key_values_part3=None
294
  position_ids_part3 = None
295
 
296
  query_position_ids = None
297
  for idx, single_block in enumerate(blocks_positions[:]):
298
+
299
+ if idx == 0 or idx <= prev_blocks_num:
 
300
  continue
301
 
302
  b_start, _, _ = single_block[0]
 
311
  true_block_length = b_end - b_start
312
 
313
  block_streaming_past_key_values_part3 = [tmp[-prev_blocks_num:] for tmp in all_past_key_values]
314
+
315
+ if offload:
316
+ block_streaming_past_key_values_part3 = [
317
+ [
318
+ (t[0].to(device=device), t[1].to(device=device))
319
+ for t in sublist
320
+ ]
321
+ for sublist in block_streaming_past_key_values_part3
322
+ ]
323
 
324
  block_streaming_past_key_values = self.cat_history_kvs(block_streaming_past_key_values_part1, block_streaming_past_key_values_part2, block_streaming_past_key_values_part3)
325
 
 
338
  key_this_block, val_this_block = pkv[i]
339
  key_this_block = key_this_block[:,:,length_before_chunk:,:]
340
  val_this_block = val_this_block[:,:,length_before_chunk:,:]
341
+
342
+ if offload:
343
+ all_past_key_values[i].append( (key_this_block.to('cpu'), val_this_block.to('cpu')) )
344
+ else:
345
+ all_past_key_values[i].append( (key_this_block, val_this_block) )
346
 
347
  time_keys_list = []
348
  time_vals_list = []
 
375
  values = torch.cat([pkv[1].to(device=device) for pkv in layer_pkvs], dim=2)
376
  merged_pkv.append((keys, values))
377
 
378
+ peak_memory_allocated = torch.cuda.max_memory_allocated()
379
+ print(f"prefill 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
380
+
381
 
382
  pkv = merged_pkv
383
  del block_streaming_past_key_values
 
390
  # TODO: bi-decoding acceleration
391
  mixed_prefill_past_key_values = pkv
392
  prefill_len = visual_token_end_pos
393
+
394
+ # torch.cuda.reset_peak_memory_stats()
395
 
396
  # Process suffix
397
  if suffix_embeds.size(1) > 0:
 
413
  return_dict=return_dict,
414
  # blocks_positions=None,
415
  )
416
+ peak_memory_allocated = torch.cuda.max_memory_allocated()
417
+ print(f"decoding 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
418
  del mixed_prefill_past_key_values
419
  torch.cuda.empty_cache()
420
 
 
519
  )
520
 
521
  if inputs_embeds is None:
522
+ pdb.set_trace()
523
  (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes, time_embedding)
524
 
525
+ if self.config.enable_chunk_prefill:
526
+
527
+ prefill_mode = self.config.prefill_config['chunk_prefill_mode']
528
+ chunk_size = self.config.prefill_config['chunk_size']
529
+ step_size = self.config.prefill_config['step_size']
530
+ offload = self.config.prefill_config['offload']
531
+
532
+ if prefill_mode=='streaming':
533
  return self.forward_streaming(
534
  input_ids=input_ids,
535
  attention_mask=attention_mask,
 
549
  frames_num=frames_num,
550
  time_token_indices=time_token_indices,
551
  time_token_end_indices=time_token_end_indices,
552
+ block_size_chosed=chunk_size,
553
+ prev_blocks_num=chunk_size - step_size,
554
+ offload=offload,
555
  )
556
+ elif prefill_mode=='mask':
557
  return self.forward_mask(
558
  input_ids=input_ids,
559
  attention_mask=attention_mask,
 
601
  **kwargs,
602
  ) -> Union[GenerateOutput, torch.LongTensor]:
603
 
604
+
605
+
606
  position_ids = kwargs.pop("position_ids", None)
607
  attention_mask = kwargs.pop("attention_mask", None)
608
 
 
650
  sample_fps=1,
651
  max_sample_fps=4,
652
  generation_config={}):
653
+ pdb.set_trace()
654
 
655
  # prepare text input
656
  conv = conv_templates["qwen_1_5"].copy()
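For reference, a configuration sketch that exercises the chunked-prefill branch added above. The key names match what forward() reads; the values are illustrative assumptions, not defaults shipped with this commit:

prefill_settings = {
    "enable_chunk_prefill": True,
    "prefill_config": {
        "chunk_prefill_mode": "streaming",   # or "mask"
        "chunk_size": 4,                     # passed on as block_size_chosed
        "step_size": 1,                      # prev_blocks_num = chunk_size - step_size = 3
        "offload": False,                    # True moves finished KV blocks to CPU
        "chunk_size_for_vision_tower": 96,   # frames per vision-tower batch (see llava_arch.py)
    },
}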
mm_utils.py CHANGED
@@ -419,6 +419,7 @@ class KeywordsStoppingCriteria(StoppingCriteria):
419
 
420
  from decord import VideoReader, cpu
421
  def load_video(video_path, max_frames_num, fps=1, max_fps=4):
 
422
  if isinstance(video_path, str):
423
  vr = VideoReader(video_path, ctx=cpu(0))
424
  else:
@@ -431,22 +432,25 @@ def load_video(video_path, max_frames_num, fps=1, max_fps=4):
431
  return None, None, []
432
 
433
  video_fps = fps
434
- step = round(avg_fps_from_decord / video_fps) if video_fps > 0 and avg_fps_from_decord > 0 else 1
435
- frame_idx = [i for i in range(0, total_frame_num, step)]
436
-
437
  fps_upbound = max_fps
438
  frames_upbound = max_frames_num
439
-
440
- if fps_upbound is not None:
441
- higher_fps = min(frames_upbound//len(frame_idx), fps_upbound)
442
- if higher_fps > video_fps:
443
- higher_steps = round(avg_fps_from_decord / higher_fps)
444
- frame_idx = [i for i in range(0, total_frame_num, higher_steps)]
445
-
446
- if frames_upbound > 0:
447
- if len(frame_idx) > frames_upbound:
448
- uniform_sampled_frames = np.linspace(0, total_frame_num - 1, frames_upbound, dtype=int)
449
- frame_idx = uniform_sampled_frames.tolist()
450
 
451
  timestamps = [round(idx / avg_fps_from_decord, 1) for idx in frame_idx]
452
  video = vr.get_batch(frame_idx).asnumpy()
 
419
 
420
  from decord import VideoReader, cpu
421
  def load_video(video_path, max_frames_num, fps=1, max_fps=4):
422
+
423
  if isinstance(video_path, str):
424
  vr = VideoReader(video_path, ctx=cpu(0))
425
  else:
 
432
  return None, None, []
433
 
434
  video_fps = fps
 
 
 
435
  fps_upbound = max_fps
436
  frames_upbound = max_frames_num
437
+ if fps is not None:
438
+ step = round(avg_fps_from_decord / video_fps) if video_fps > 0 and avg_fps_from_decord > 0 else 1
439
+ frame_idx = [i for i in range(0, total_frame_num, step)]
440
+
441
+ if fps_upbound is not None:
442
+ higher_fps = min(frames_upbound//len(frame_idx), fps_upbound)
443
+ if higher_fps > video_fps:
444
+ higher_steps = round(avg_fps_from_decord / higher_fps)
445
+ frame_idx = [i for i in range(0, total_frame_num, higher_steps)]
446
+
447
+ if frames_upbound > 0:
448
+ if len(frame_idx) > frames_upbound:
449
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, frames_upbound, dtype=int)
450
+ frame_idx = uniform_sampled_frames.tolist()
451
+ else: # fall back to uniform sampling when no fps is given
452
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, frames_upbound, dtype=int)
453
+ frame_idx = uniform_sampled_frames.tolist()
454
 
455
  timestamps = [round(idx / avg_fps_from_decord, 1) for idx in frame_idx]
456
  video = vr.get_batch(frame_idx).asnumpy()
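A worked example of the rewritten frame-selection logic, using hypothetical clip numbers (10 minutes at 30 fps, fps=1, max_fps=4, max_frames_num=128):

import numpy as np

avg_fps_from_decord, total_frame_num = 30, 18000
fps, max_fps, max_frames_num = 1, 4, 128

step = round(avg_fps_from_decord / fps)                       # 30 -> one frame per second
frame_idx = list(range(0, total_frame_num, step))             # 600 candidate frames
higher_fps = min(max_frames_num // len(frame_idx), max_fps)   # 0 here, so the fps boost is skipped
if len(frame_idx) > max_frames_num:                           # cap at 128 uniformly spread frames
    frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
timestamps = [round(i / avg_fps_from_decord, 1) for i in frame_idx]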
modeling_qwen2.py CHANGED
@@ -688,7 +688,10 @@ class Qwen2SdpaAttention(Qwen2Attention):
688
 
689
  try:
690
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, key_position_ids)
691
- except:
 
 
 
692
  pdb.set_trace()
693
  key_states = repeat_kv(key_states, self.num_key_value_groups)
694
  value_states = repeat_kv(value_states, self.num_key_value_groups)
 
688
 
689
  try:
690
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, key_position_ids)
691
+ except Exception as e:
692
+ print(e)
693
+ import traceback
694
+ traceback.print_exc()
695
  pdb.set_trace()
696
  key_states = repeat_kv(key_states, self.num_key_value_groups)
697
  value_states = repeat_kv(value_states, self.num_key_value_groups)
sae.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+
3
+ from .sae_utils import SamePadConv3d,Normalize,SiLU,TemporalAttention,AttnBlock3D,MultiHeadAttention3D,TemporalAttention_lin
4
+ import torch.nn as nn
5
+ import pdb
6
+
7
+ class SiglipAE(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+ temporal_stride=2
11
+ norm_type = "group"
12
+
13
+ self.temporal_encoding = nn.Parameter(torch.randn((4,1152)))
14
+ #self.vision_tower=SigLipVisionTower('google/siglip-so400m-patch14-384')
15
+ self.encoder=nn.Sequential(
16
+ AttnBlock3D(1152),
17
+ TemporalAttention(1152),
18
+
19
+ SamePadConv3d(1152,1152,kernel_size=3,stride=(temporal_stride, 1, 1),padding_type="replicate"),
20
+
21
+ AttnBlock3D(1152),
22
+ TemporalAttention(1152),
23
+
24
+ SamePadConv3d(1152,1152,kernel_size=3,stride=(temporal_stride, 1, 1),padding_type="replicate"),
25
+
26
+ )
27
+ def forward(self, x):
28
+ b_,c_,t_,h_,w_=x.shape
29
+
30
+ temporal_encoding = self.temporal_encoding.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
31
+ temporal_encoding = temporal_encoding.expand(b_, -1, -1, h_, w_) # (B, T, C, H, W)
32
+ temporal_encoding = temporal_encoding.permute(0, 2, 1, 3, 4) # (B, C, T, H, W)
33
+ x = x + temporal_encoding
34
+
35
+ x=self.encoder(x)
36
+ return x
37
+ # image=torch.randn(1,1152,4,24,24).to('cuda')
38
+
39
+
40
+ # model = SiglipAE().to('cuda')
41
+ # model.load_state_dict(torch.load('encoder.pth'),strict=False)
42
+
43
+ # image=model(image)
44
+
45
+ # print(image.shape)
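Shape walk-through (assuming the (B, 1152, 4, 24, 24) feature groups built in llava_arch.py): the two stride-2 temporal convolutions in SiglipAE compress 4 frames into a single temporal step, which the caller then squeezes away.

import torch
x = torch.randn(2, 1152, 4, 24, 24)   # a group of 4 SigLip feature maps
y = SiglipAE()(x)                     # torch.Size([2, 1152, 1, 24, 24])
tokens = y.squeeze(2)                 # (2, 1152, 24, 24), the shape consumed in encode_multimodals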
sae_utils.py ADDED
@@ -0,0 +1,302 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers.activations import ACT2FN
5
+ from .attention_temporal_videoae import *
6
+ from einops import rearrange, reduce, repeat
7
+
8
+ try:
9
+ import xformers
10
+ import xformers.ops as xops
11
+
12
+ XFORMERS_IS_AVAILBLE = True
13
+ except ImportError:
14
+ XFORMERS_IS_AVAILBLE = False
15
+
16
+ def silu(x):
17
+ # swish
18
+ return x * torch.sigmoid(x)
19
+
20
+
21
+ class SiLU(nn.Module):
22
+ def __init__(self):
23
+ super(SiLU, self).__init__()
24
+
25
+ def forward(self, x):
26
+ return silu(x)
27
+
28
+
29
+ def Normalize(in_channels, norm_type="group"):
30
+ assert norm_type in ["group", "batch",'layer']
31
+ if norm_type == "group":
32
+ return torch.nn.GroupNorm(
33
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
34
+ )
35
+ elif norm_type == "batch":
36
+ return torch.nn.SyncBatchNorm(in_channels)
37
+ elif norm_type == "layer":
38
+ return nn.LayerNorm(in_channels)
39
+
40
+ class SamePadConv3d(nn.Module):
41
+ def __init__(
42
+ self,
43
+ in_channels,
44
+ out_channels,
45
+ kernel_size,
46
+ stride=1,
47
+ bias=True,
48
+ padding_type="replicate",
49
+ ):
50
+ super().__init__()
51
+ if isinstance(kernel_size, int):
52
+ kernel_size = (kernel_size,) * 3
53
+ if isinstance(stride, int):
54
+ stride = (stride,) * 3
55
+
56
+ # assumes that the input shape is divisible by stride
57
+ total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
58
+ pad_input = []
59
+ for p in total_pad[::-1]: # reverse since F.pad starts from last dim
60
+ pad_input.append((p // 2 + p % 2, p // 2))
61
+ pad_input = sum(pad_input, tuple())
62
+
63
+ self.pad_input = pad_input
64
+ self.padding_type = padding_type
65
+
66
+ self.conv = nn.Conv3d(
67
+ in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias
68
+ )
69
+
70
+ def forward(self, x):
71
+ tp=x.dtype
72
+ x = x.float()
73
+
74
+ # apply the padding in float32
75
+ x_padded = F.pad(x, self.pad_input, mode=self.padding_type)
76
+
77
+ # cast the padded result back to the original dtype (e.g. BFloat16)
78
+ x_padded = x_padded.to(tp)
79
+
80
+ return self.conv(x_padded)
81
+
82
+ class TemporalAttention(nn.Module):
83
+ def __init__(
84
+ self,
85
+ channels,
86
+ num_heads=1,
87
+ num_head_channels=-1,
88
+ max_temporal_length=64,
89
+ ):
90
+ """
91
+ a clean multi-head temporal attention
92
+ """
93
+ super().__init__()
94
+
95
+ if num_head_channels == -1:
96
+ self.num_heads = num_heads
97
+ else:
98
+ assert (
99
+ channels % num_head_channels == 0
100
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
101
+ self.num_heads = channels // num_head_channels
102
+
103
+ self.norm = Normalize(channels)
104
+ self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
105
+ self.attention = QKVAttention(self.num_heads)
106
+ self.relative_position_k = RelativePosition(
107
+ num_units=channels // self.num_heads,
108
+ max_relative_position=max_temporal_length,
109
+ )
110
+ self.relative_position_v = RelativePosition(
111
+ num_units=channels // self.num_heads,
112
+ max_relative_position=max_temporal_length,
113
+ )
114
+ self.proj_out = zero_module(
115
+ conv_nd(1, channels, channels, 1)
116
+ ) # conv_dim, in_channels, out_channels, kernel_size
117
+
118
+ def forward(self, x, mask=None):
119
+ b, c, t, h, w = x.shape
120
+ out = rearrange(x, "b c t h w -> (b h w) c t")
121
+ # torch.Size([4608, 1152, 2])1
122
+ # torch.Size([4608, 3456, 2])2
123
+ # torch.Size([4608, 1152, 2])3
124
+ # torch.Size([4608, 1152, 2])4
125
+ #print(out.shape,end='1\n')
126
+ qkv = self.qkv(self.norm(out))
127
+ #print(qkv.shape,end='2\n')
128
+
129
+ len_q = qkv.size()[-1]
130
+ len_k, len_v = len_q, len_q
131
+
132
+ k_rp = self.relative_position_k(len_q, len_k)
133
+ v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
134
+ out = self.attention(qkv, rp=(k_rp, v_rp))
135
+ #print(out.shape,end='3\n')
136
+ out = self.proj_out(out)
137
+ #print(out.shape,end='4\n')
138
+ out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
139
+
140
+ return x + out
141
+ class TemporalAttention_lin(nn.Module):
142
+ def __init__(
143
+ self,
144
+ channels,
145
+ num_heads=8,
146
+ num_head_channels=-1,
147
+ max_temporal_length=64,
148
+ ):
149
+ """
150
+ a clean multi-head temporal attention
151
+ """
152
+ super().__init__()
153
+
154
+ if num_head_channels == -1:
155
+ self.num_heads = num_heads
156
+ else:
157
+ assert (
158
+ channels % num_head_channels == 0
159
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
160
+ self.num_heads = channels // num_head_channels
161
+
162
+ self.norm = nn.LayerNorm(channels)
163
+ #self.norm = Normalize(channels)
164
+ #self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
165
+ self.qkv = nn.Linear(channels, channels * 3)
166
+ self.attention = QKVAttention(self.num_heads)
167
+ self.relative_position_k = RelativePosition(
168
+ num_units=channels // self.num_heads,
169
+ max_relative_position=max_temporal_length,
170
+ )
171
+ self.relative_position_v = RelativePosition(
172
+ num_units=channels // self.num_heads,
173
+ max_relative_position=max_temporal_length,
174
+ )
175
+ self.proj_out = nn.Linear(channels, channels)
176
+
177
+ def forward(self, x, mask=None):
178
+ b, c, t, h, w = x.shape
179
+ out = rearrange(x, "b c t h w -> (b h w) t c")
180
+ # torch.Size([4608, 1152, 2])1
181
+ # torch.Size([4608, 3456, 2])2
182
+ # torch.Size([4608, 1152, 2])3
183
+ # torch.Size([4608, 1152, 2])4
184
+ #print(out.shape,end='1\n')
185
+ qkv = self.qkv(self.norm(out)).transpose(-1, -2)
186
+ #print(qkv.shape,end='2\n')
187
+
188
+ len_q = qkv.size()[-1]
189
+ len_k, len_v = len_q, len_q
190
+
191
+ k_rp = self.relative_position_k(len_q, len_k)
192
+ v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
193
+
194
+ out = self.attention(qkv, rp=(k_rp, v_rp))
195
+
196
+ out = self.proj_out(out.transpose(-1, -2)).transpose(-1, -2)
197
+
198
+ #print(out.shape,end='4\n')
199
+ out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
200
+
201
+ return x + out
202
+
203
+ class AttnBlock3D(nn.Module):
204
+ def __init__(self, in_channels):
205
+ super().__init__()
206
+ self.in_channels = in_channels
207
+
208
+ self.norm = Normalize(in_channels)
209
+ self.q = torch.nn.Conv3d(
210
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
211
+ )
212
+ self.k = torch.nn.Conv3d(
213
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
214
+ )
215
+ self.v = torch.nn.Conv3d(
216
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
217
+ )
218
+ self.proj_out = torch.nn.Conv3d(
219
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
220
+ )
221
+
222
+ def forward(self, x):
223
+ h_ = x
224
+ # self.norm.to(x.device)
225
+ # self.norm.to(x.dtype)
226
+ h_ = self.norm(h_)
227
+ q = self.q(h_)
228
+ k = self.k(h_)
229
+ v = self.v(h_)
230
+
231
+ b, c, t, h, w = q.shape
232
+ # q = q.reshape(b,c,h*w) # bcl
233
+ # q = q.permute(0,2,1) # bcl -> blc l=hw
234
+ # k = k.reshape(b,c,h*w) # bcl
235
+ q = rearrange(q, "b c t h w -> (b t) (h w) c") # blc
236
+ k = rearrange(k, "b c t h w -> (b t) c (h w)") # bcl
237
+
238
+ w_ = torch.bmm(q, k) # b,l,l
239
+ w_ = w_ * (int(c) ** (-0.5))
240
+ w_ = torch.nn.functional.softmax(w_, dim=2)
241
+
242
+ # v = v.reshape(b,c,h*w)
243
+ v = rearrange(v, "b c t h w -> (b t) c (h w)") # bcl
244
+
245
+ # attend to values
246
+ w_ = w_.permute(0, 2, 1) # bll
247
+ h_ = torch.bmm(v, w_) # bcl
248
+
249
+ # h_ = h_.reshape(b,c,h,w)
250
+ h_ = rearrange(h_, "(b t) c (h w) -> b c t h w", b=b, h=h)
251
+
252
+ h_ = self.proj_out(h_)
253
+
254
+ return x + h_
255
+
256
+ class MultiHeadAttention3D(nn.Module):
257
+ def __init__(self, in_channels, num_heads=8):
258
+ super().__init__()
259
+ self.in_channels = in_channels
260
+ self.num_heads = num_heads
261
+ self.head_dim = in_channels // num_heads
262
+
263
+ assert self.head_dim * num_heads == in_channels, "in_channels must be divisible by num_heads"
264
+
265
+ self.norm = nn.LayerNorm(in_channels)
266
+ self.q_linear = nn.Linear(in_channels, in_channels)
267
+ self.k_linear = nn.Linear(in_channels, in_channels)
268
+ self.v_linear = nn.Linear(in_channels, in_channels)
269
+ self.proj_out = nn.Linear(in_channels, in_channels)
270
+
271
+ def forward(self, x):
272
+ b, c, t, h, w = x.shape
273
+ #print(x.shape)
274
+ # Normalize and reshape input
275
+ h_ = rearrange(x, "b c t h w -> (b t) (h w) c")
276
+ h_ = self.norm(h_)
277
+
278
+ # Linear projections
279
+ q = self.q_linear(h_)
280
+ k = self.k_linear(h_)
281
+ v = self.v_linear(h_)
282
+
283
+ # Reshape to multi-head
284
+ q = rearrange(q, "b l (h d) -> b h l d", h=self.num_heads)
285
+ k = rearrange(k, "b l (h d) -> b h l d", h=self.num_heads)
286
+ v = rearrange(v, "b l (h d) -> b h l d", h=self.num_heads)
287
+
288
+ # Scaled Dot-Product Attention
289
+ scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
290
+ attn = F.softmax(scores, dim=-1)
291
+
292
+ # Apply attention to values
293
+ out = torch.matmul(attn, v)
294
+ out = rearrange(out, "b h l d -> b l (h d)")
295
+
296
+ # Project back to original dimension
297
+ out = self.proj_out(out)
298
+
299
+ # Reshape back to original shape
300
+ out = rearrange(out, "(b t) (h w) c -> b c t h w", b=b, h=h, t=t)
301
+ #print(out.shape)
302
+ return x + out
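SamePadConv3d derives asymmetric "same" padding from kernel_size - stride, so with kernel 3 the stride-2 temporal conv halves T while H and W are preserved. A small sketch with the sizes used elsewhere in this commit:

import torch
conv = SamePadConv3d(1152, 1152, kernel_size=3, stride=(2, 1, 1), padding_type="replicate")
print(conv.pad_input)                  # (1, 1, 1, 1, 1, 0): (W, W, H, H, T_front, T_back)
x = torch.randn(1, 1152, 4, 24, 24)
y = conv(x)                            # torch.Size([1, 1152, 2, 24, 24])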
siglip_encoder.py ADDED
@@ -0,0 +1,154 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from typing import Optional, Tuple, Union, Dict
5
+ from PIL import Image
6
+ from functools import partial, reduce
7
+ from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
8
+
9
+ from .base_encoder import BaseVisionTower
10
+ import torch.distributed as dist
11
+ # --data_path /share/shuyan/video_traindata/anno/\{cinepine_order\}.json \
12
+ # --image_folder /share/shuyan/video_traindata/Bunny-v1_0-data/finetune/images \
13
+ # --video_folder /share/shuyan/video_traindata \
14
+ def rank0_print(*args):
15
+ if dist.is_initialized():
16
+ if dist.get_rank() == 0:
17
+ print(f"Rank {dist.get_rank()}: ", *args)
18
+ else:
19
+ print(*args)
20
+
21
+
22
+ from transformers.image_processing_utils import BatchFeature, get_size_dict
23
+ from transformers.image_transforms import (
24
+ convert_to_rgb,
25
+ normalize,
26
+ rescale,
27
+ resize,
28
+ to_channel_dimension_format,
29
+ )
30
+ from transformers.image_utils import (
31
+ ChannelDimension,
32
+ PILImageResampling,
33
+ to_numpy_array,
34
+ )
35
+ class SigLipImageProcessor:
36
+ def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
37
+ crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
38
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
39
+
40
+ self.image_mean = image_mean
41
+ self.image_std = image_std
42
+ self.size = size
43
+ self.resample = resample
44
+ self.rescale_factor = rescale_factor
45
+ self.data_format = data_format
46
+ self.crop_size = crop_size
47
+
48
+ def preprocess(self, images, return_tensors):
49
+ if isinstance(images, Image.Image):
50
+ images = [images]
51
+ else:
52
+ # to adapt video data
53
+ images = [to_numpy_array(image) for image in images]
54
+ assert isinstance(images, list)
55
+
56
+ transforms = [
57
+ convert_to_rgb,
58
+ to_numpy_array,
59
+ partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
60
+ partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
61
+ partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
62
+ partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
63
+ ]
64
+
65
+ images = reduce(lambda x, f: [*map(f, x)], transforms, images)
66
+
67
+ data = {"pixel_values": images}
68
+
69
+ return BatchFeature(data=data, tensor_type=return_tensors)
70
+
71
+ class SigLipVisionTower(BaseVisionTower):
72
+ def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
73
+ super(SigLipVisionTower, self).__init__(vision_tower_name, vision_tower_cfg, delay_load)
74
+
75
+ # model_path = "google/siglip-so400m-patch14-384"
76
+ # base_model_name, res, interp = model_path, 384, 576
77
+ # self.vision_tower_name = base_model_name
78
+ self.vision_tower_name, res, interp = vision_tower_name, 384, 576
79
+ self._image_size = res if res is not None else 512
80
+ self.unfreeze_mm_vision_tower = getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False)
81
+
82
+ if not delay_load:
83
+ rank0_print(f"Loading vision tower: {vision_tower_name}")
84
+ self.load_model()
85
+ elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
86
+ # TODO: better detector is needed.
87
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
88
+ self.load_model()
89
+ elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
90
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
91
+ self.load_model()
92
+ else:
93
+ self.cfg_only = self.config
94
+
95
+ def load_model(self, device_map=None):
96
+ self.vision_model = "siglip"
97
+ # clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
98
+ print(self.vision_tower_name)
99
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
100
+
101
+ # self.vision_tower = clip_model.visual.trunk
102
+ self.vision_tower.output_tokens = True
103
+
104
+ self._hidden_size = self.vision_tower.config.hidden_size
105
+
106
+ self.image_processor = SigLipImageProcessor()
107
+
108
+ del self.vision_tower.vision_model.encoder.layers[-1:]
109
+ self.vision_tower.vision_model.head = nn.Identity()
110
+
111
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
112
+
113
+ self.is_loaded = True
114
+
115
+ def _forward(self, images):
116
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
117
+ image_features = self.vision_tower.forward(
118
+ images.to(device=self.device, dtype=self.dtype),
119
+ output_hidden_states=True,
120
+ ).hidden_states[-1]
121
+ return image_features
122
+ @property
123
+ def dummy_feature(self):
124
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
125
+
126
+ @property
127
+ def dtype(self):
128
+ for p in self.vision_tower.parameters():
129
+ return p.dtype
130
+
131
+ @property
132
+ def device(self):
133
+ for p in self.vision_tower.parameters():
134
+ return p.device
135
+
136
+ @property
137
+ def hidden_size(self):
138
+ return self.config.hidden_size
139
+
140
+ @property
141
+ def num_patches(self):
142
+ return (336 // 14) ** 2
143
+
144
+ @property
145
+ def num_patches_per_side(self):
146
+ #return self.config.image_size // self.config.patch_size
147
+ return 336//14
148
+ #return 27
149
+ # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
150
+
151
+ @property
152
+ def image_size(self):
153
+ return 384
154
+ #return self.config.image_size
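Usage sketch for the processor and tower above, assuming the google/siglip-so400m-patch14-384 checkpoint referenced in the comments (27x27 patch grid, hidden size 1152); the file name and vision_tower_cfg are hypothetical:

from PIL import Image
processor = SigLipImageProcessor()
pixel_values = processor.preprocess(Image.open("frame.jpg"), return_tensors="pt")["pixel_values"]
# pixel_values: torch.Size([1, 3, 384, 384])
# tower = SigLipVisionTower("google/siglip-so400m-patch14-384", vision_tower_cfg)
# tower._forward(pixel_values)  # penultimate-layer patch tokens, (1, 729, 1152) for this checkpoint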
utils.py ADDED
@@ -0,0 +1,166 @@
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+ import numpy as np
7
+
8
+ import requests
9
+
10
+ from .constants import LOGDIR
11
+
12
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
13
+ moderation_msg = "I am sorry. Your input may violate our content moderation guidelines. Please avoid using harmful or offensive content."
14
+
15
+ handler = None
16
+
17
+ import torch.distributed as dist
18
+
19
+ try:
20
+ import av
21
+ except ImportError:
22
+ print("Please install pyav to use video processing functions.")
23
+
24
+
25
+ def process_video_with_pyav(video_file, data_args):
26
+ container = av.open(video_file)
27
+ stream = container.streams.video[0]
28
+ total_frame_num = stream.frames
29
+ avg_fps = round(stream.average_rate / data_args.video_fps)
30
+ frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
31
+ if data_args.frames_upbound > 0:
32
+ if len(frame_idx) > data_args.frames_upbound:
33
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int)
34
+ frame_idx = uniform_sampled_frames.tolist()
35
+
36
+ video_frames = []
37
+ for index, frame in enumerate(container.decode(video=0)):
38
+ if index in frame_idx:
39
+ video_frames.append(frame.to_rgb().to_ndarray())
40
+ if len(video_frames) == len(frame_idx): # Stop decoding once we have all needed frames
41
+ break
42
+
43
+ video = np.stack(video_frames)
44
+ return video
45
+
46
+
47
+ def rank0_print(*args):
48
+ if dist.is_initialized():
49
+ if dist.get_rank() == 0:
50
+ print(f"Rank {dist.get_rank()}: ", *args)
51
+ else:
52
+ print(*args)
53
+
54
+
55
+ def build_logger(logger_name, logger_filename):
56
+ global handler
57
+
58
+ formatter = logging.Formatter(
59
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
60
+ datefmt="%Y-%m-%d %H:%M:%S",
61
+ )
62
+
63
+ # Set the format of root handlers
64
+ if not logging.getLogger().handlers:
65
+ logging.basicConfig(level=logging.INFO)
66
+ logging.getLogger().handlers[0].setFormatter(formatter)
67
+
68
+ # Redirect stdout and stderr to loggers
69
+ stdout_logger = logging.getLogger("stdout")
70
+ stdout_logger.setLevel(logging.INFO)
71
+ sl = StreamToLogger(stdout_logger, logging.INFO)
72
+ sys.stdout = sl
73
+
74
+ stderr_logger = logging.getLogger("stderr")
75
+ stderr_logger.setLevel(logging.ERROR)
76
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
77
+ sys.stderr = sl
78
+
79
+ # Get logger
80
+ logger = logging.getLogger(logger_name)
81
+ logger.setLevel(logging.INFO)
82
+
83
+ # Add a file handler for all loggers
84
+ if handler is None:
85
+ os.makedirs(LOGDIR, exist_ok=True)
86
+ filename = os.path.join(LOGDIR, logger_filename)
87
+ handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True)
88
+ handler.setFormatter(formatter)
89
+
90
+ for name, item in logging.root.manager.loggerDict.items():
91
+ if isinstance(item, logging.Logger):
92
+ item.addHandler(handler)
93
+
94
+ return logger
95
+
96
+
97
+ class StreamToLogger(object):
98
+ """
99
+ Fake file-like stream object that redirects writes to a logger instance.
100
+ """
101
+
102
+ def __init__(self, logger, log_level=logging.INFO):
103
+ self.terminal = sys.stdout
104
+ self.logger = logger
105
+ self.log_level = log_level
106
+ self.linebuf = ""
107
+
108
+ def __getattr__(self, attr):
109
+ return getattr(self.terminal, attr)
110
+
111
+ def write(self, buf):
112
+ temp_linebuf = self.linebuf + buf
113
+ self.linebuf = ""
114
+ for line in temp_linebuf.splitlines(True):
115
+ # From the io.TextIOWrapper docs:
116
+ # On output, if newline is None, any '\n' characters written
117
+ # are translated to the system default line separator.
118
+ # By default sys.stdout.write() expects '\n' newlines and then
119
+ # translates them so this is still cross platform.
120
+ if line[-1] == "\n":
121
+ self.logger.log(self.log_level, line.rstrip())
122
+ else:
123
+ self.linebuf += line
124
+
125
+ def flush(self):
126
+ if self.linebuf != "":
127
+ self.logger.log(self.log_level, self.linebuf.rstrip())
128
+ self.linebuf = ""
129
+
130
+
131
+ def disable_torch_init():
132
+ """
133
+ Disable the redundant torch default initialization to accelerate model creation.
134
+ """
135
+ import torch
136
+
137
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
138
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
139
+
140
+
141
+ def violates_moderation(text):
142
+ """
143
+ Check whether the text violates OpenAI moderation API.
144
+ """
145
+ url = "https://api.openai.com/v1/moderations"
146
+ headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
147
+ text = text.replace("\n", "")
148
+ data = "{" + '"input": ' + f'"{text}"' + "}"
149
+ data = data.encode("utf-8")
150
+ try:
151
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
152
+ flagged = ret.json()["results"][0]["flagged"]
153
+ except requests.exceptions.RequestException as e:
154
+ print(f"######################### Moderation Error: {e} #########################")
155
+ flagged = False
156
+ except KeyError as e:
157
+ print(f"######################### Moderation Error: {e} #########################")
158
+ flagged = False
159
+
160
+ return flagged
161
+
162
+
163
+ def pretty_print_semaphore(semaphore):
164
+ if semaphore is None:
165
+ return "None"
166
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
utils_encoder.py ADDED
@@ -0,0 +1,296 @@
1
+ import importlib
2
+ import numpy as np
3
+ import cv2, os
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+
8
+ def count_params(model, verbose=False):
9
+ total_params = sum(p.numel() for p in model.parameters())
10
+ if verbose:
11
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
12
+ return total_params
13
+
14
+
15
+ def check_istarget(name, para_list):
16
+ """
17
+ name: full name of source para
18
+ para_list: partial name of target para
19
+ """
20
+ istarget = False
21
+ for para in para_list:
22
+ if para in name:
23
+ return True
24
+ return istarget
25
+
26
+
27
+ def instantiate_from_config(config):
28
+ if not "target" in config:
29
+ if config == "__is_first_stage__":
30
+ return None
31
+ elif config == "__is_unconditional__":
32
+ return None
33
+ raise KeyError("Expected key `target` to instantiate.")
34
+
35
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
36
+
37
+
38
+ def get_obj_from_str(string, reload=False):
39
+ module, cls = string.rsplit(".", 1)
40
+ if reload:
41
+ module_imp = importlib.import_module(module)
42
+ importlib.reload(module_imp)
43
+ return getattr(importlib.import_module(module, package=None), cls)
44
+
45
+
46
+ def load_npz_from_dir(data_dir):
47
+ data = [
48
+ np.load(os.path.join(data_dir, data_name))["arr_0"]
49
+ for data_name in os.listdir(data_dir)
50
+ ]
51
+ data = np.concatenate(data, axis=0)
52
+ return data
53
+
54
+
55
+ def load_npz_from_paths(data_paths):
56
+ data = [np.load(data_path)["arr_0"] for data_path in data_paths]
57
+ data = np.concatenate(data, axis=0)
58
+ return data
59
+
60
+
61
+ def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
62
+ h, w = image.shape[:2]
63
+ if resize_short_edge is not None:
64
+ k = resize_short_edge / min(h, w)
65
+ else:
66
+ k = max_resolution / (h * w)
67
+ k = k**0.5
68
+ h = int(np.round(h * k / 64)) * 64
69
+ w = int(np.round(w * k / 64)) * 64
70
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
71
+ return image
72
+
73
+
74
+ def setup_dist(args):
75
+ if dist.is_initialized():
76
+ return
77
+ torch.cuda.set_device(args.local_rank)
78
+ torch.distributed.init_process_group("nccl", init_method="env://")
79
+
80
+
81
+ # adopted from
82
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
83
+ # and
84
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
85
+ # and
86
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
87
+ #
88
+ # thanks!
89
+
90
+ import torch.nn as nn
91
+ import math
92
+ from inspect import isfunction
93
+ import torch
94
+ from torch import nn
95
+ import torch.distributed as dist
96
+
97
+
98
+ def gather_data(data, return_np=True):
99
+ """gather data from multiple processes to one list"""
100
+ data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
101
+ dist.all_gather(data_list, data) # gather not supported with NCCL
102
+ if return_np:
103
+ data_list = [data.cpu().numpy() for data in data_list]
104
+ return data_list
105
+
106
+
107
+ def autocast(f):
108
+ def do_autocast(*args, **kwargs):
109
+ with torch.cuda.amp.autocast(
110
+ enabled=True,
111
+ dtype=torch.get_autocast_gpu_dtype(),
112
+ cache_enabled=torch.is_autocast_cache_enabled(),
113
+ ):
114
+ return f(*args, **kwargs)
115
+
116
+ return do_autocast
117
+
118
+
119
+ def extract_into_tensor(a, t, x_shape):
120
+ b, *_ = t.shape
121
+ out = a.gather(-1, t)
122
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
123
+
124
+
125
+ def noise_like(shape, device, repeat=False):
126
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
127
+ shape[0], *((1,) * (len(shape) - 1))
128
+ )
129
+ noise = lambda: torch.randn(shape, device=device)
130
+ return repeat_noise() if repeat else noise()
131
+
132
+
133
+ def default(val, d):
134
+ if exists(val):
135
+ return val
136
+ return d() if isfunction(d) else d
137
+
138
+
139
+ def exists(val):
140
+ return val is not None
141
+
142
+
143
+ def identity(*args, **kwargs):
144
+ return nn.Identity()
145
+
146
+
147
+ def uniq(arr):
148
+ return {el: True for el in arr}.keys()
149
+
150
+
151
+ def mean_flat(tensor):
152
+ """
153
+ Take the mean over all non-batch dimensions.
154
+ """
155
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
156
+
157
+
158
+ def ismap(x):
159
+ if not isinstance(x, torch.Tensor):
160
+ return False
161
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
162
+
163
+
164
+ def isimage(x):
165
+ if not isinstance(x, torch.Tensor):
166
+ return False
167
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
168
+
169
+
170
+ def max_neg_value(t):
171
+ return -torch.finfo(t.dtype).max
172
+
173
+
174
+ def shape_to_str(x):
175
+ shape_str = "x".join([str(x) for x in x.shape])
176
+ return shape_str
177
+
178
+
179
+ def init_(tensor):
180
+ dim = tensor.shape[-1]
181
+ std = 1 / math.sqrt(dim)
182
+ tensor.uniform_(-std, std)
183
+ return tensor
184
+
185
+
186
+ # ckpt = torch.utils.checkpoint.checkpoint
187
+
188
+
189
+ # def checkpoint(func, inputs, params, flag):
190
+ # """
191
+ # Evaluate a function without caching intermediate activations, allowing for
192
+ # reduced memory at the expense of extra compute in the backward pass.
193
+ # :param func: the function to evaluate.
194
+ # :param inputs: the argument sequence to pass to `func`.
195
+ # :param params: a sequence of parameters `func` depends on but does not
196
+ # explicitly take as arguments.
197
+ # :param flag: if False, disable gradient checkpointing.
198
+ # """
199
+ # if flag:
200
+ # return ckpt(func, *inputs)
201
+ # else:
202
+ # return func(*inputs)
203
+
204
+
205
+ def disabled_train(self, mode=True):
206
+ """Overwrite model.train with this function to make sure train/eval mode
207
+ does not change anymore."""
208
+ return self
209
+
210
+
211
+ def zero_module(module):
212
+ """
213
+ Zero out the parameters of a module and return it.
214
+ """
215
+ for p in module.parameters():
216
+ p.detach().zero_()
217
+ return module
218
+
219
+
220
+ def scale_module(module, scale):
221
+ """
222
+ Scale the parameters of a module and return it.
223
+ """
224
+ for p in module.parameters():
225
+ p.detach().mul_(scale)
226
+ return module
227
+
228
+
229
+ def conv_nd(dims, *args, **kwargs):
230
+ """
231
+ Create a 1D, 2D, or 3D convolution module.
232
+ """
233
+ if dims == 1:
234
+ return nn.Conv1d(*args, **kwargs)
235
+ elif dims == 2:
236
+ return nn.Conv2d(*args, **kwargs)
237
+ elif dims == 3:
238
+ return nn.Conv3d(*args, **kwargs)
239
+ raise ValueError(f"unsupported dimensions: {dims}")
240
+
241
+
242
+ def linear(*args, **kwargs):
243
+ """
244
+ Create a linear module.
245
+ """
246
+ return nn.Linear(*args, **kwargs)
247
+
248
+
249
+ def avg_pool_nd(dims, *args, **kwargs):
250
+ """
251
+ Create a 1D, 2D, or 3D average pooling module.
252
+ """
253
+ if dims == 1:
254
+ return nn.AvgPool1d(*args, **kwargs)
255
+ elif dims == 2:
256
+ return nn.AvgPool2d(*args, **kwargs)
257
+ elif dims == 3:
258
+ return nn.AvgPool3d(*args, **kwargs)
259
+ raise ValueError(f"unsupported dimensions: {dims}")
260
+
261
+
262
+ def nonlinearity(type="silu"):
263
+ if type == "silu":
264
+ return nn.SiLU()
265
+ elif type == "leaky_relu":
266
+ return nn.LeakyReLU()
267
+
268
+
269
+ class GroupNormSpecific(nn.GroupNorm):
270
+ def forward(self, x):
271
+ if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
272
+ return super().forward(x).type(x.dtype)
273
+ else:
274
+ return super().forward(x.float()).type(x.dtype)
275
+
276
+
277
+ def normalization(channels, num_groups=32):
278
+ """
279
+ Make a standard normalization layer.
280
+ :param channels: number of input channels.
281
+ :return: an nn.Module for normalization.
282
+ """
283
+ return GroupNormSpecific(num_groups, channels)
284
+
285
+
286
+ class HybridConditioner(nn.Module):
287
+
288
+ def __init__(self, c_concat_config, c_crossattn_config):
289
+ super().__init__()
290
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
291
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
292
+
293
+ def forward(self, c_concat, c_crossattn):
294
+ c_concat = self.concat_conditioner(c_concat)
295
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
296
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}