BAAI /
Commit 5644dea committed by 3v324v23 · 1 Parent(s): d1ed198
Files changed (37)
  1. attention_temporal_videoae.py +0 -1314
  2. base_encoder.py +0 -68
  3. builder.py +0 -17
  4. config.json +4 -2
  5. configuration_qwen2.py +0 -169
  6. llava_arch.py +30 -105
  7. llava_qwen.py +673 -9
  8. modeling_qwen2.py +2 -0
  9. multimodal_encoder/.ipynb_checkpoints/base_encoder-checkpoint.py +0 -68
  10. multimodal_encoder/.ipynb_checkpoints/builder-checkpoint.py +0 -29
  11. multimodal_encoder/.ipynb_checkpoints/clip_encoder-checkpoint.py +0 -179
  12. multimodal_encoder/.ipynb_checkpoints/siglip_encoder-checkpoint.py +0 -151
  13. multimodal_encoder/__pycache__/base_encoder.cpython-310.pyc +0 -0
  14. multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  15. multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
  16. multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
  17. multimodal_encoder/base_encoder.py +0 -68
  18. multimodal_encoder/builder.py +0 -20
  19. multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
  20. multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc +0 -0
  21. multimodal_projector/pooler_projector.py +0 -33
  22. multimodal_resampler/__pycache__/builder.cpython-310.pyc +0 -0
  23. multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc +0 -0
  24. multimodal_resampler/__pycache__/perceiver.cpython-310.pyc +0 -0
  25. multimodal_resampler/__pycache__/qformer.cpython-310.pyc +0 -0
  26. multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc +0 -0
  27. multimodal_resampler/builder.py +0 -34
  28. multimodal_resampler/masked_drop.py +0 -80
  29. multimodal_resampler/perceiver.py +0 -155
  30. multimodal_resampler/qformer.py +0 -1160
  31. sae.py +1434 -10
  32. sae_utils.py +0 -302
  33. siglip_encoder.py +0 -154
  34. utils_encoder.py +0 -296
  35. multimodal_projector/builder.py → vision_projector_builder.py +29 -1
  36. multimodal_resampler/spatial_pool.py → vision_resampler_builder.py +23 -0
  37. multimodal_encoder/siglip_encoder.py → vision_tower_builder.py +92 -17
attention_temporal_videoae.py DELETED
@@ -1,1314 +0,0 @@
1
- from inspect import isfunction
2
- import math
3
- import torch
4
- import torch as th
5
- import torch.nn.functional as F
6
- from torch import nn, einsum
7
- from einops import rearrange, repeat
8
- from typing import Optional, Any
9
-
10
- try:
11
- import xformers
12
- import xformers.ops
13
-
14
- XFORMERS_IS_AVAILBLE = True
15
- except:
16
- XFORMERS_IS_AVAILBLE = False
17
-
18
- from .utils_encoder import (
19
- conv_nd,
20
- zero_module,
21
- normalization,
22
- )
23
-
24
-
25
- def exists(val):
26
- return val is not None
27
-
28
-
29
- def uniq(arr):
30
- return {el: True for el in arr}.keys()
31
-
32
-
33
- def default(val, d):
34
- if exists(val):
35
- return val
36
- return d() if isfunction(d) else d
37
-
38
-
39
- def max_neg_value(t):
40
- return -torch.finfo(t.dtype).max
41
-
42
-
43
- def init_(tensor):
44
- dim = tensor.shape[-1]
45
- std = 1 / math.sqrt(dim)
46
- tensor.uniform_(-std, std)
47
- return tensor
48
-
49
-
50
- # feedforward
51
- class GEGLU(nn.Module):
52
- def __init__(self, dim_in, dim_out):
53
- super().__init__()
54
- self.proj = nn.Linear(dim_in, dim_out * 2)
55
-
56
- def forward(self, x):
57
- x, gate = self.proj(x).chunk(2, dim=-1)
58
- return x * F.gelu(gate)
59
-
60
-
61
- class FeedForward(nn.Module):
62
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
63
- super().__init__()
64
- inner_dim = int(dim * mult)
65
- dim_out = default(dim_out, dim)
66
- project_in = (
67
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
68
- if not glu
69
- else GEGLU(dim, inner_dim)
70
- )
71
-
72
- self.net = nn.Sequential(
73
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
74
- )
75
-
76
- def forward(self, x):
77
- return self.net(x)
78
-
79
-
80
- def zero_module(module):
81
- """
82
- Zero out the parameters of a module and return it.
83
- """
84
- for p in module.parameters():
85
- p.detach().zero_()
86
- return module
87
-
88
-
89
- def Normalize(in_channels, num_groups=32):
90
- return torch.nn.GroupNorm(
91
- num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
92
- )
93
-
94
-
95
- # ---------------------------------------------------------------------------------------------------
96
- class RelativePosition(nn.Module):
97
- """https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py"""
98
-
99
- def __init__(self, num_units, max_relative_position):
100
- super().__init__()
101
- self.num_units = num_units
102
- self.max_relative_position = max_relative_position
103
- self.embeddings_table = nn.Parameter(
104
- th.Tensor(max_relative_position * 2 + 1, num_units)
105
- )
106
- nn.init.xavier_uniform_(self.embeddings_table)
107
-
108
- def forward(self, length_q, length_k):
109
- device = self.embeddings_table.device
110
- range_vec_q = th.arange(length_q, device=device)
111
- range_vec_k = th.arange(length_k, device=device)
112
- distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
113
- distance_mat_clipped = th.clamp(
114
- distance_mat, -self.max_relative_position, self.max_relative_position
115
- )
116
- final_mat = distance_mat_clipped + self.max_relative_position
117
- # final_mat = th.LongTensor(final_mat).to(self.embeddings_table.device)
118
- # final_mat = th.tensor(final_mat, device=self.embeddings_table.device, dtype=torch.long)
119
- final_mat = final_mat.long()
120
- embeddings = self.embeddings_table[final_mat]
121
- return embeddings
122
-
123
-
124
- class TemporalCrossAttention(nn.Module):
125
- def __init__(
126
- self,
127
- query_dim,
128
- context_dim=None,
129
- heads=8,
130
- dim_head=64,
131
- dropout=0.0,
132
- temporal_length=None, # For relative positional representation and image-video joint training.
133
- image_length=None, # For image-video joint training.
134
- use_relative_position=False, # whether use relative positional representation in temporal attention.
135
- img_video_joint_train=False, # For image-video joint training.
136
- use_tempoal_causal_attn=False,
137
- bidirectional_causal_attn=False,
138
- tempoal_attn_type=None,
139
- joint_train_mode="same_batch",
140
- **kwargs,
141
- ):
142
- super().__init__()
143
- inner_dim = dim_head * heads
144
- context_dim = default(context_dim, query_dim)
145
- self.context_dim = context_dim
146
-
147
- self.scale = dim_head**-0.5
148
- self.heads = heads
149
- self.temporal_length = temporal_length
150
- self.use_relative_position = use_relative_position
151
- self.img_video_joint_train = img_video_joint_train
152
- self.bidirectional_causal_attn = bidirectional_causal_attn
153
- self.joint_train_mode = joint_train_mode
154
- assert joint_train_mode in ["same_batch", "diff_batch"]
155
- self.tempoal_attn_type = tempoal_attn_type
156
-
157
- if bidirectional_causal_attn:
158
- assert use_tempoal_causal_attn
159
- if tempoal_attn_type:
160
- assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"]
161
- assert not use_tempoal_causal_attn
162
- assert not (
163
- img_video_joint_train and (self.joint_train_mode == "same_batch")
164
- )
165
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
166
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
167
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
168
-
169
- assert not (
170
- img_video_joint_train
171
- and (self.joint_train_mode == "same_batch")
172
- and use_tempoal_causal_attn
173
- )
174
- if img_video_joint_train:
175
- if self.joint_train_mode == "same_batch":
176
- mask = torch.ones(
177
- [1, temporal_length + image_length, temporal_length + image_length]
178
- )
179
- # mask[:, image_length:, :] = 0
180
- # mask[:, :, image_length:] = 0
181
- mask[:, temporal_length:, :] = 0
182
- mask[:, :, temporal_length:] = 0
183
- self.mask = mask
184
- else:
185
- self.mask = None
186
- elif use_tempoal_causal_attn:
187
- # normal causal attn
188
- self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
189
- elif tempoal_attn_type == "sparse_causal":
190
- # all frames interact with only the `prev` & self frame
191
- mask1 = torch.tril(
192
- torch.ones([1, temporal_length, temporal_length])
193
- ).bool() # true indicates keeping
194
- mask2 = torch.zeros(
195
- [1, temporal_length, temporal_length]
196
- ) # initialize to same shape with mask1
197
- mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
198
- torch.ones([1, temporal_length - 2, temporal_length - 2])
199
- )
200
- mask2 = (1 - mask2).bool() # false indicates masking
201
- self.mask = mask1 & mask2
202
- elif tempoal_attn_type == "sparse_causal_first":
203
- # all frames interact with only the `first` & self frame
204
- mask1 = torch.tril(
205
- torch.ones([1, temporal_length, temporal_length])
206
- ).bool() # true indicates keeping
207
- mask2 = torch.zeros([1, temporal_length, temporal_length])
208
- mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril(
209
- torch.ones([1, temporal_length - 2, temporal_length - 2])
210
- )
211
- mask2 = (1 - mask2).bool() # false indicates masking
212
- self.mask = mask1 & mask2
213
- else:
214
- self.mask = None
215
-
216
- if use_relative_position:
217
- assert temporal_length is not None
218
- self.relative_position_k = RelativePosition(
219
- num_units=dim_head, max_relative_position=temporal_length
220
- )
221
- self.relative_position_v = RelativePosition(
222
- num_units=dim_head, max_relative_position=temporal_length
223
- )
224
-
225
- self.to_out = nn.Sequential(
226
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
227
- )
228
-
229
- nn.init.constant_(self.to_q.weight, 0)
230
- nn.init.constant_(self.to_k.weight, 0)
231
- nn.init.constant_(self.to_v.weight, 0)
232
- nn.init.constant_(self.to_out[0].weight, 0)
233
- nn.init.constant_(self.to_out[0].bias, 0)
234
-
235
- def forward(self, x, context=None, mask=None):
236
- # if context is None:
237
- # print(f'[Temp Attn] x={x.shape},context=None')
238
- # else:
239
- # print(f'[Temp Attn] x={x.shape},context={context.shape}')
240
-
241
- nh = self.heads
242
- out = x
243
- q = self.to_q(out)
244
- # if context is not None:
245
- # print(f'temporal context 1 ={context.shape}')
246
- # print(f'x={x.shape}')
247
- context = default(context, x)
248
- # print(f'temporal context 2 ={context.shape}')
249
- k = self.to_k(context)
250
- v = self.to_v(context)
251
- # print(f'q ={q.shape},k={k.shape}')
252
-
253
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v))
254
- sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
255
-
256
- if self.use_relative_position:
257
- len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
258
- k2 = self.relative_position_k(len_q, len_k)
259
- sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale # TODO check
260
- sim += sim2
261
- # print('mask',mask)
262
- if exists(self.mask):
263
- if mask is None:
264
- mask = self.mask.to(sim.device)
265
- else:
266
- mask = self.mask.to(sim.device).bool() & mask # .to(sim.device)
267
- else:
268
- mask = mask
269
- # if self.img_video_joint_train:
270
- # # process mask (make mask same shape with sim)
271
- # c, h, w = mask.shape
272
- # c, t, s = sim.shape
273
- # # assert(h == w and t == s),f"mask={mask.shape}, sim={sim.shape}, h={h}, w={w}, t={t}, s={s}"
274
-
275
- # if h > t:
276
- # mask = mask[:, :t, :]
277
- # elif h < t: # pad zeros to mask (no attention) only initial mask =1 area compute weights
278
- # mask_ = torch.zeros([c,t,w]).to(mask.device)
279
- # mask_[:, :h, :] = mask
280
- # mask = mask_
281
- # c, h, w = mask.shape
282
- # if w > s:
283
- # mask = mask[:, :, :s]
284
- # elif w < s: # pad zeros to mask
285
- # mask_ = torch.zeros([c,h,s]).to(mask.device)
286
- # mask_[:, :, :w] = mask
287
- # mask = mask_
288
-
289
- # max_neg_value = -torch.finfo(sim.dtype).max
290
- # sim = sim.float().masked_fill(mask == 0, max_neg_value)
291
- if mask is not None:
292
- max_neg_value = -1e9
293
- sim = sim + (1 - mask.float()) * max_neg_value # 1=masking,0=no masking
294
- # print('sim after masking: ', sim)
295
-
296
- # if torch.isnan(sim).any() or torch.isinf(sim).any() or (not sim.any()):
297
- # print(f'sim [after masking], isnan={torch.isnan(sim).any()}, isinf={torch.isinf(sim).any()}, allzero={not sim.any()}')
298
-
299
- attn = sim.softmax(dim=-1)
300
- # print('attn after softmax: ', attn)
301
- # if torch.isnan(attn).any() or torch.isinf(attn).any() or (not attn.any()):
302
- # print(f'attn [after softmax], isnan={torch.isnan(attn).any()}, isinf={torch.isinf(attn).any()}, allzero={not attn.any()}')
303
-
304
- # attn = torch.where(torch.isnan(attn), torch.full_like(attn,0), attn)
305
- # if torch.isinf(attn.detach()).any():
306
- # import pdb;pdb.set_trace()
307
- # if torch.isnan(attn.detach()).any():
308
- # import pdb;pdb.set_trace()
309
- out = einsum("b i j, b j d -> b i d", attn, v)
310
-
311
- if self.bidirectional_causal_attn:
312
- mask_reverse = torch.triu(
313
- torch.ones(
314
- [1, self.temporal_length, self.temporal_length], device=sim.device
315
- )
316
- )
317
- sim_reverse = sim.float().masked_fill(mask_reverse == 0, max_neg_value)
318
- attn_reverse = sim_reverse.softmax(dim=-1)
319
- out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
320
- out += out_reverse
321
-
322
- if self.use_relative_position:
323
- v2 = self.relative_position_v(len_q, len_v)
324
- out2 = einsum("b t s, t s d -> b t d", attn, v2) # TODO check
325
- out += out2 # TODO check:先add还是先merge head?先计算rpr,on split head之后的数据,然后再merge。
326
- out = rearrange(out, "(b h) n d -> b n (h d)", h=nh) # merge head
327
- return self.to_out(out)
328
-
329
-
330
- # ---------------------------------------------------------------------------------------------------
331
-
332
-
333
- class SpatialSelfAttention(nn.Module):
334
- def __init__(self, in_channels):
335
- super().__init__()
336
- self.in_channels = in_channels
337
-
338
- self.norm = Normalize(in_channels)
339
- self.q = torch.nn.Conv2d(
340
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
341
- )
342
- self.k = torch.nn.Conv2d(
343
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
344
- )
345
- self.v = torch.nn.Conv2d(
346
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
347
- )
348
- self.proj_out = torch.nn.Conv2d(
349
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
350
- )
351
-
352
- def forward(self, x):
353
- h_ = x
354
- h_ = self.norm(h_)
355
- q = self.q(h_)
356
- k = self.k(h_)
357
- v = self.v(h_)
358
-
359
- # compute attention
360
- b, c, h, w = q.shape
361
- q = rearrange(q, "b c h w -> b (h w) c")
362
- k = rearrange(k, "b c h w -> b c (h w)")
363
- w_ = torch.einsum("bij,bjk->bik", q, k)
364
-
365
- w_ = w_ * (int(c) ** (-0.5))
366
- w_ = torch.nn.functional.softmax(w_, dim=2)
367
-
368
- # attend to values
369
- v = rearrange(v, "b c h w -> b c (h w)")
370
- w_ = rearrange(w_, "b i j -> b j i")
371
- h_ = torch.einsum("bij,bjk->bik", v, w_)
372
- h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
373
- h_ = self.proj_out(h_)
374
-
375
- return x + h_
376
-
377
-
378
- class CrossAttention(nn.Module):
379
- def __init__(
380
- self,
381
- query_dim,
382
- context_dim=None,
383
- heads=8,
384
- dim_head=64,
385
- dropout=0.0,
386
- sa_shared_kv=False,
387
- shared_type="only_first",
388
- **kwargs,
389
- ):
390
- super().__init__()
391
- inner_dim = dim_head * heads
392
- context_dim = default(context_dim, query_dim)
393
- self.sa_shared_kv = sa_shared_kv
394
- assert shared_type in [
395
- "only_first",
396
- "all_frames",
397
- "first_and_prev",
398
- "only_prev",
399
- "full",
400
- "causal",
401
- "full_qkv",
402
- ]
403
- self.shared_type = shared_type
404
-
405
- self.scale = dim_head**-0.5
406
- self.heads = heads
407
- self.dim_head = dim_head
408
-
409
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
410
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
411
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
412
-
413
- self.to_out = nn.Sequential(
414
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
415
- )
416
- self.attention_op: Optional[Any] = None
417
-
418
- def forward(self, x, context=None, mask=None):
419
- h = self.heads
420
- b = x.shape[0]
421
-
422
- q = self.to_q(x)
423
- context = default(context, x)
424
- k = self.to_k(context)
425
- v = self.to_v(context)
426
- if self.sa_shared_kv:
427
- if self.shared_type == "only_first":
428
- k, v = map(
429
- lambda xx: rearrange(xx[0].unsqueeze(0), "b n c -> (b n) c")
430
- .unsqueeze(0)
431
- .repeat(b, 1, 1),
432
- (k, v),
433
- )
434
- else:
435
- raise NotImplementedError
436
-
437
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
438
-
439
- sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
440
-
441
- if exists(mask):
442
- mask = rearrange(mask, "b ... -> b (...)")
443
- max_neg_value = -torch.finfo(sim.dtype).max
444
- mask = repeat(mask, "b j -> (b h) () j", h=h)
445
- sim.masked_fill_(~mask, max_neg_value)
446
-
447
- # attention, what we cannot get enough of
448
- attn = sim.softmax(dim=-1)
449
-
450
- out = einsum("b i j, b j d -> b i d", attn, v)
451
- out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
452
- return self.to_out(out)
453
-
454
- def efficient_forward(self, x, context=None, mask=None):
455
- q = self.to_q(x)
456
- context = default(context, x)
457
- k = self.to_k(context)
458
- v = self.to_v(context)
459
-
460
- b, _, _ = q.shape
461
- q, k, v = map(
462
- lambda t: t.unsqueeze(3)
463
- .reshape(b, t.shape[1], self.heads, self.dim_head)
464
- .permute(0, 2, 1, 3)
465
- .reshape(b * self.heads, t.shape[1], self.dim_head)
466
- .contiguous(),
467
- (q, k, v),
468
- )
469
- # actually compute the attention, what we cannot get enough of
470
- out = xformers.ops.memory_efficient_attention(
471
- q, k, v, attn_bias=None, op=self.attention_op
472
- )
473
-
474
- if exists(mask):
475
- raise NotImplementedError
476
- out = (
477
- out.unsqueeze(0)
478
- .reshape(b, self.heads, out.shape[1], self.dim_head)
479
- .permute(0, 2, 1, 3)
480
- .reshape(b, out.shape[1], self.heads * self.dim_head)
481
- )
482
- return self.to_out(out)
483
-
484
-
485
- class VideoSpatialCrossAttention(CrossAttention):
486
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
487
- super().__init__(query_dim, context_dim, heads, dim_head, dropout)
488
-
489
- def forward(self, x, context=None, mask=None):
490
- b, c, t, h, w = x.shape
491
- if context is not None:
492
- context = context.repeat(t, 1, 1)
493
- x = super().forward(spatial_attn_reshape(x), context=context) + x
494
- return spatial_attn_reshape_back(x, b, h)
495
-
496
-
497
- # class BasicTransformerBlockST(nn.Module):
498
- # def __init__(
499
- # self,
500
- # # Spatial Stuff
501
- # dim,
502
- # n_heads,
503
- # d_head,
504
- # dropout=0.0,
505
- # context_dim=None,
506
- # gated_ff=True,
507
- # checkpoint=True,
508
- # # Temporal Stuff
509
- # temporal_length=None,
510
- # image_length=None,
511
- # use_relative_position=True,
512
- # img_video_joint_train=False,
513
- # cross_attn_on_tempoal=False,
514
- # temporal_crossattn_type="selfattn",
515
- # order="stst",
516
- # temporalcrossfirst=False,
517
- # temporal_context_dim=None,
518
- # split_stcontext=False,
519
- # local_spatial_temporal_attn=False,
520
- # window_size=2,
521
- # random_t=False,
522
- # **kwargs,
523
- # ):
524
- # super().__init__()
525
- # # Self attention
526
- # self.attn1 = CrossAttention(
527
- # query_dim=dim,
528
- # heads=n_heads,
529
- # dim_head=d_head,
530
- # dropout=dropout,
531
- # **kwargs,
532
- # )
533
- # self.attn2 = CrossAttention(
534
- # query_dim=dim,
535
- # context_dim=context_dim,
536
- # heads=n_heads,
537
- # dim_head=d_head,
538
- # dropout=dropout,
539
- # **kwargs,
540
- # )
541
- # if XFORMERS_IS_AVAILBLE:
542
- # self.attn1.forward = self.attn1.efficient_forward
543
- # self.attn2.forward = self.attn2.efficient_forward
544
-
545
- # self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
546
- # # cross attention if context is not None
547
-
548
- # self.norm1 = nn.LayerNorm(dim)
549
- # self.norm2 = nn.LayerNorm(dim)
550
- # self.norm3 = nn.LayerNorm(dim)
551
- # self.checkpoint = checkpoint
552
- # self.order = order
553
- # assert self.order in ["stst", "sstt", "st_parallel"]
554
- # self.temporalcrossfirst = temporalcrossfirst
555
- # self.split_stcontext = split_stcontext
556
- # self.local_spatial_temporal_attn = local_spatial_temporal_attn
557
- # if self.local_spatial_temporal_attn:
558
- # assert self.order == "stst"
559
- # assert self.order == "stst"
560
- # self.window_size = window_size
561
- # if not split_stcontext:
562
- # temporal_context_dim = context_dim
563
- # # Temporal attention
564
- # assert temporal_crossattn_type in ["selfattn", "crossattn", "skip"]
565
- # self.temporal_crossattn_type = temporal_crossattn_type
566
- # self.attn1_tmp = TemporalCrossAttention(
567
- # query_dim=dim,
568
- # heads=n_heads,
569
- # dim_head=d_head,
570
- # dropout=dropout,
571
- # temporal_length=temporal_length,
572
- # image_length=image_length,
573
- # use_relative_position=use_relative_position,
574
- # img_video_joint_train=img_video_joint_train,
575
- # **kwargs,
576
- # )
577
- # self.attn2_tmp = TemporalCrossAttention(
578
- # query_dim=dim,
579
- # heads=n_heads,
580
- # dim_head=d_head,
581
- # dropout=dropout,
582
- # # cross attn
583
- # context_dim=(
584
- # temporal_context_dim if temporal_crossattn_type == "crossattn" else None
585
- # ),
586
- # # temporal attn
587
- # temporal_length=temporal_length,
588
- # image_length=image_length,
589
- # use_relative_position=use_relative_position,
590
- # img_video_joint_train=img_video_joint_train,
591
- # **kwargs,
592
- # )
593
- # self.norm4 = nn.LayerNorm(dim)
594
- # self.norm5 = nn.LayerNorm(dim)
595
- # self.random_t = random_t
596
- # # self.norm1_tmp = nn.LayerNorm(dim)
597
- # # self.norm2_tmp = nn.LayerNorm(dim)
598
-
599
- # ##############################################################################################################################################
600
- # def forward(
601
- # self,
602
- # x,
603
- # context=None,
604
- # temporal_context=None,
605
- # no_temporal_attn=None,
606
- # attn_mask=None,
607
- # **kwargs,
608
- # ):
609
- # # print(f'no_temporal_attn={no_temporal_attn}')
610
-
611
- # if not self.split_stcontext:
612
- # # st cross attention use the same context vector
613
- # temporal_context = context.detach().clone()
614
-
615
- # if context is None and temporal_context is None:
616
- # # self-attention models
617
- # if no_temporal_attn:
618
- # raise NotImplementedError
619
- # return checkpoint(
620
- # self._forward_nocontext, (x), self.parameters(), self.checkpoint
621
- # )
622
- # else:
623
- # # cross-attention models
624
- # if no_temporal_attn:
625
- # forward_func = self._forward_no_temporal_attn
626
- # else:
627
- # forward_func = self._forward
628
- # inputs = (
629
- # (x, context, temporal_context)
630
- # if temporal_context is not None
631
- # else (x, context)
632
- # )
633
- # return checkpoint(forward_func, inputs, self.parameters(), self.checkpoint)
634
- # # if attn_mask is not None:
635
- # # return checkpoint(self._forward, (x, context, temporal_context, attn_mask), self.parameters(), self.checkpoint)
636
- # # return checkpoint(self._forward, (x, context, temporal_context), self.parameters(), self.checkpoint)
637
-
638
- # def _forward(
639
- # self,
640
- # x,
641
- # context=None,
642
- # temporal_context=None,
643
- # mask=None,
644
- # no_temporal_attn=None,
645
- # ):
646
- # assert x.dim() == 5, f"x shape = {x.shape}"
647
- # b, c, t, h, w = x.shape
648
-
649
- # if self.order in ["stst", "sstt"]:
650
- # x = self._st_cross_attn(
651
- # x,
652
- # context,
653
- # temporal_context=temporal_context,
654
- # order=self.order,
655
- # mask=mask,
656
- # ) # no_temporal_attn=no_temporal_attn,
657
- # elif self.order == "st_parallel":
658
- # x = self._st_cross_attn_parallel(
659
- # x,
660
- # context,
661
- # temporal_context=temporal_context,
662
- # order=self.order,
663
- # ) # no_temporal_attn=no_temporal_attn,
664
- # else:
665
- # raise NotImplementedError
666
-
667
- # x = self.ff(self.norm3(x)) + x
668
- # if (no_temporal_attn is None) or (not no_temporal_attn):
669
- # x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
670
- # elif no_temporal_attn:
671
- # x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
672
- # return x
673
-
674
- # def _forward_no_temporal_attn(
675
- # self,
676
- # x,
677
- # context=None,
678
- # temporal_context=None,
679
- # ):
680
- # # temporary implementation :(
681
- # # because checkpoint does not support non-tensor inputs currently.
682
- # assert x.dim() == 5, f"x shape = {x.shape}"
683
- # b, c, t, h, w = x.shape
684
-
685
- # if self.order in ["stst", "sstt"]:
686
- # # x = self._st_cross_attn(x, context, temporal_context=temporal_context, order=self.order, no_temporal_attn=True,)
687
- # # mask = torch.zeros([1, t, t], device=x.device).bool() if context is None else torch.zeros([1, context.shape[1], t], device=x.device).bool()
688
- # mask = torch.zeros([1, t, t], device=x.device).bool()
689
- # x = self._st_cross_attn(
690
- # x,
691
- # context,
692
- # temporal_context=temporal_context,
693
- # order=self.order,
694
- # mask=mask,
695
- # )
696
- # elif self.order == "st_parallel":
697
- # x = self._st_cross_attn_parallel(
698
- # x,
699
- # context,
700
- # temporal_context=temporal_context,
701
- # order=self.order,
702
- # no_temporal_attn=True,
703
- # )
704
- # else:
705
- # raise NotImplementedError
706
-
707
- # x = self.ff(self.norm3(x)) + x
708
- # x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
709
- # # x = rearrange(x, '(b t) (h w) c -> b c t h w', b=b,h=h,w=w) # 3d -> 5d
710
- # return x
711
-
712
- # def _forward_nocontext(self, x, no_temporal_attn=None):
713
- # assert x.dim() == 5, f"x shape = {x.shape}"
714
- # b, c, t, h, w = x.shape
715
-
716
- # if self.order in ["stst", "sstt"]:
717
- # x = self._st_cross_attn(
718
- # x, order=self.order, no_temporal_attn=no_temporal_attn
719
- # )
720
- # elif self.order == "st_parallel":
721
- # x = self._st_cross_attn_parallel(
722
- # x, order=self.order, no_temporal_attn=no_temporal_attn
723
- # )
724
- # else:
725
- # raise NotImplementedError
726
-
727
- # x = self.ff(self.norm3(x)) + x
728
- # x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
729
-
730
- # return x
731
-
732
- # ##############################################################################################################################################
733
-
734
- # def _st_cross_attn(
735
- # self, x, context=None, temporal_context=None, order="stst", mask=None
736
- # ): # no_temporal_attn=None,
737
- # b, c, t, h, w = x.shape
738
- # # if context is not None:
739
- # # print(f'[_st_cross_attn input] x={x.shape}, context={context.shape}')
740
- # # else:
741
- # # print(f'[_st_cross_attn input] x={x.shape}')
742
-
743
- # if order == "stst":
744
- # # spatial self attention
745
- # x = rearrange(x, "b c t h w -> (b t) (h w) c")
746
- # # print(f'before attn1,x={x.shape}')
747
-
748
- # x = self.attn1(self.norm1(x)) + x
749
- # x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
750
-
751
- # # temporal self attention
752
- # # if (no_temporal_attn is None) or (not no_temporal_attn):
753
- # if self.local_spatial_temporal_attn:
754
- # x = local_spatial_temporal_attn_reshape(x, window_size=self.window_size)
755
- # else:
756
- # x = rearrange(x, "b c t h w -> (b h w) t c")
757
- # x = self.attn1_tmp(self.norm4(x), mask=mask) + x
758
-
759
- # if self.local_spatial_temporal_attn:
760
- # x = local_spatial_temporal_attn_reshape_back(
761
- # x, window_size=self.window_size, b=b, h=h, w=w, t=t
762
- # )
763
- # else:
764
- # x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
765
-
766
- # # spatial cross attention
767
- # x = rearrange(x, "b c t h w -> (b t) (h w) c")
768
- # # print(f'before attn2, x={x.shape}')
769
- # # if context is not None:
770
- # # print(f'[before attn2] context={context.shape}')
771
- # if context is not None:
772
- # if self.random_t:
773
- # context_ = []
774
- # for i in range(context.shape[0]):
775
- # context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
776
- # context_ = torch.cat(context_, dim=0)
777
- # else:
778
- # if context.shape[0] == t: # img captions no_temporal_attn or
779
- # context_ = context
780
- # else:
781
- # # repeat conditions with t times
782
- # context_ = []
783
- # for i in range(context.shape[0]):
784
- # context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
785
- # context_ = torch.cat(context_, dim=0)
786
- # else:
787
- # context_ = None
788
-
789
- # # if context_ is not None:
790
- # # print(f'[before attn2] x={x.shape}, context_={context_.shape}')
791
- # # else:
792
- # # print(f'[before attn2] x={x.shape}')
793
-
794
- # x = self.attn2(self.norm2(x), context=context_) + x
795
-
796
- # # temporal cross attention
797
- # # if (no_temporal_attn is None) or (not no_temporal_attn):
798
- # x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
799
- # x = rearrange(x, "b c t h w -> (b h w) t c")
800
- # if self.temporal_crossattn_type == "crossattn":
801
- # # tmporal cross attention
802
- # if temporal_context is not None:
803
- # # print(f'STATTN context={context.shape}, temporal_context={temporal_context.shape}')
804
- # temporal_context = torch.cat(
805
- # [context, temporal_context], dim=1
806
- # ) # blc
807
- # # print(f'STATTN after concat temporal_context={temporal_context.shape}')
808
- # temporal_context = temporal_context.repeat(h * w, 1, 1)
809
- # # print(f'after repeat temporal_context={temporal_context.shape}')
810
- # else:
811
- # temporal_context = context[0:1, ...].repeat(h * w, 1, 1)
812
- # # print(f'STATTN after concat x={x.shape}')
813
- # x = (
814
- # self.attn2_tmp(self.norm5(x), context=temporal_context, mask=mask)
815
- # + x
816
- # )
817
- # elif self.temporal_crossattn_type == "selfattn":
818
- # # temporal self attention
819
- # x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
820
- # elif self.temporal_crossattn_type == "skip":
821
- # # no temporal cross and self attention
822
- # pass
823
- # else:
824
- # raise NotImplementedError
825
-
826
- # elif order == "sstt":
827
- # # spatial self attention
828
- # x = rearrange(x, "b c t h w -> (b t) (h w) c")
829
- # x = self.attn1(self.norm1(x)) + x
830
-
831
- # # spatial cross attention
832
- # context_ = context.repeat(t, 1, 1) if context is not None else None
833
- # x = self.attn2(self.norm2(x), context=context_) + x
834
- # x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
835
-
836
- # if (no_temporal_attn is None) or (not no_temporal_attn):
837
- # if self.temporalcrossfirst:
838
- # # temporal cross attention
839
- # if self.temporal_crossattn_type == "crossattn":
840
- # # if temporal_context is not None:
841
- # temporal_context = context.repeat(h * w, 1, 1)
842
- # x = (
843
- # self.attn2_tmp(
844
- # self.norm5(x), context=temporal_context, mask=mask
845
- # )
846
- # + x
847
- # )
848
- # elif self.temporal_crossattn_type == "selfattn":
849
- # x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
850
- # elif self.temporal_crossattn_type == "skip":
851
- # pass
852
- # else:
853
- # raise NotImplementedError
854
- # # temporal self attention
855
- # x = rearrange(x, "b c t h w -> (b h w) t c")
856
- # x = self.attn1_tmp(self.norm4(x), mask=mask) + x
857
- # else:
858
- # # temporal self attention
859
- # x = rearrange(x, "b c t h w -> (b h w) t c")
860
- # x = self.attn1_tmp(self.norm4(x), mask=mask) + x
861
- # # temporal cross attention
862
- # if self.temporal_crossattn_type == "crossattn":
863
- # if temporal_context is not None:
864
- # temporal_context = context.repeat(h * w, 1, 1)
865
- # x = (
866
- # self.attn2_tmp(
867
- # self.norm5(x), context=temporal_context, mask=mask
868
- # )
869
- # + x
870
- # )
871
- # elif self.temporal_crossattn_type == "selfattn":
872
- # x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
873
- # elif self.temporal_crossattn_type == "skip":
874
- # pass
875
- # else:
876
- # raise NotImplementedError
877
- # else:
878
- # raise NotImplementedError
879
-
880
- # return x
881
-
882
- # def _st_cross_attn_parallel(
883
- # self, x, context=None, temporal_context=None, order="sst", no_temporal_attn=None
884
- # ):
885
- # """order: x -> Self Attn -> Cross Attn -> attn_s
886
- # x -> Temp Self Attn -> attn_t
887
- # x' = x + attn_s + attn_t
888
- # """
889
- # if no_temporal_attn is not None:
890
- # raise NotImplementedError
891
-
892
- # B, C, T, H, W = x.shape
893
- # # spatial self attention
894
- # h = x
895
- # h = rearrange(h, "b c t h w -> (b t) (h w) c")
896
- # h = self.attn1(self.norm1(h)) + h
897
- # # spatial cross
898
- # # context_ = context.repeat(T, 1, 1) if context is not None else None
899
- # if context is not None:
900
- # context_ = []
901
- # for i in range(context.shape[0]):
902
- # context_.append(context[i].unsqueeze(0).repeat(T, 1, 1))
903
- # context_ = torch.cat(context_, dim=0)
904
- # else:
905
- # context_ = None
906
-
907
- # h = self.attn2(self.norm2(h), context=context_) + h
908
- # h = rearrange(h, "(b t) (h w) c -> b c t h w", b=B, h=H)
909
-
910
- # # temporal self
911
- # h2 = x
912
- # h2 = rearrange(h2, "b c t h w -> (b h w) t c")
913
- # h2 = self.attn1_tmp(self.norm4(h2)) # + h2
914
- # h2 = rearrange(h2, "(b h w) t c -> b c t h w", b=B, h=H, w=W)
915
- # out = h + h2
916
- # return rearrange(out, "b c t h w -> (b h w) t c")
917
-
918
- ##############################################################################################################################################
919
-
920
-
921
- def spatial_attn_reshape(x):
922
- return rearrange(x, "b c t h w -> (b t) (h w) c")
923
-
924
-
925
- def spatial_attn_reshape_back(x, b, h):
926
- return rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
927
-
928
-
929
- def temporal_attn_reshape(x):
930
- return rearrange(x, "b c t h w -> (b h w) t c")
931
-
932
-
933
- def temporal_attn_reshape_back(x, b, h, w):
934
- return rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)
935
-
936
-
937
- def local_spatial_temporal_attn_reshape(x, window_size):
938
- B, C, T, H, W = x.shape
939
- NH = H // window_size
940
- NW = W // window_size
941
- # x = x.view(B, C, T, NH, window_size, NW, window_size)
942
- # tokens = x.permute(0, 1, 2, 3, 5, 4, 6).contiguous()
943
- # tokens = tokens.view(-1, window_size, window_size, C)
944
- x = rearrange(
945
- x,
946
- "b c t (nh wh) (nw ww) -> b c t nh wh nw ww",
947
- nh=NH,
948
- nw=NW,
949
- wh=window_size,
950
- ww=window_size,
951
- ).contiguous() # # B, C, T, NH, NW, window_size, window_size
952
- x = rearrange(
953
- x, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c"
954
- ) # (B, NH, NW) (T, window_size, window_size) C
955
- return x
956
-
957
-
958
- def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t):
959
- B, L, C = x.shape
960
- NH = h // window_size
961
- NW = w // window_size
962
- x = rearrange(
963
- x,
964
- "(b nh nw) (t wh ww) c -> b c t nh wh nw ww",
965
- b=b,
966
- nh=NH,
967
- nw=NW,
968
- t=t,
969
- wh=window_size,
970
- ww=window_size,
971
- )
972
- x = rearrange(x, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)")
973
- return x
974
-
975
-
976
- class SpatialTemporalTransformer(nn.Module):
977
- """
978
- Transformer block for video-like data (5D tensor).
979
- First, project the input (aka embedding) with NO reshape.
980
- Then apply standard transformer action.
981
- The 5D -> 3D reshape operation will be done in the specific attention module.
982
- """
983
-
984
- def __init__(
985
- self,
986
- in_channels,
987
- n_heads,
988
- d_head,
989
- depth=1,
990
- dropout=0.0,
991
- context_dim=None,
992
- # Temporal stuff
993
- temporal_length=None,
994
- image_length=None,
995
- use_relative_position=True,
996
- img_video_joint_train=False,
997
- cross_attn_on_tempoal=False,
998
- temporal_crossattn_type="selfattn",
999
- order="stst",
1000
- temporalcrossfirst=False,
1001
- split_stcontext=False,
1002
- temporal_context_dim=None,
1003
- **kwargs,
1004
- ):
1005
- super().__init__()
1006
-
1007
- self.in_channels = in_channels
1008
- inner_dim = n_heads * d_head
1009
-
1010
- self.norm = Normalize(in_channels)
1011
- self.proj_in = nn.Conv3d(
1012
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
1013
- )
1014
-
1015
- self.transformer_blocks = nn.ModuleList(
1016
- [
1017
- BasicTransformerBlockST(
1018
- inner_dim,
1019
- n_heads,
1020
- d_head,
1021
- dropout=dropout,
1022
- # cross attn
1023
- context_dim=context_dim,
1024
- # temporal attn
1025
- temporal_length=temporal_length,
1026
- image_length=image_length,
1027
- use_relative_position=use_relative_position,
1028
- img_video_joint_train=img_video_joint_train,
1029
- temporal_crossattn_type=temporal_crossattn_type,
1030
- order=order,
1031
- temporalcrossfirst=temporalcrossfirst,
1032
- split_stcontext=split_stcontext,
1033
- temporal_context_dim=temporal_context_dim,
1034
- **kwargs,
1035
- )
1036
- for d in range(depth)
1037
- ]
1038
- )
1039
-
1040
- self.proj_out = zero_module(
1041
- nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
1042
- )
1043
-
1044
- def forward(self, x, context=None, temporal_context=None, **kwargs):
1045
- # note: if no context is given, cross-attention defaults to self-attention
1046
- assert x.dim() == 5, f"x shape = {x.shape}"
1047
- b, c, t, h, w = x.shape
1048
- x_in = x
1049
-
1050
- x = self.norm(x)
1051
- x = self.proj_in(x)
1052
-
1053
- for block in self.transformer_blocks:
1054
- x = block(x, context=context, temporal_context=temporal_context, **kwargs)
1055
-
1056
- x = self.proj_out(x)
1057
- return x + x_in
1058
-
1059
-
1060
- # ---------------------------------------------------------------------------------------------------
1061
-
1062
-
1063
- class STAttentionBlock2(nn.Module):
1064
- def __init__(
1065
- self,
1066
- channels,
1067
- num_heads=1,
1068
- num_head_channels=-1,
1069
- use_checkpoint=False, # not used, only used in ResBlock
1070
- use_new_attention_order=False, # QKVAttention or QKVAttentionLegacy
1071
- temporal_length=16, # used in relative positional representation.
1072
- image_length=8, # used for image-video joint training.
1073
- use_relative_position=False, # whether use relative positional representation in temporal attention.
1074
- img_video_joint_train=False,
1075
- # norm_type="groupnorm",
1076
- attn_norm_type="group",
1077
- use_tempoal_causal_attn=False,
1078
- ):
1079
- """
1080
- version 1: guided_diffusion implemented version
1081
- version 2: remove args input argument
1082
- """
1083
- super().__init__()
1084
-
1085
- if num_head_channels == -1:
1086
- self.num_heads = num_heads
1087
- else:
1088
- assert (
1089
- channels % num_head_channels == 0
1090
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
1091
- self.num_heads = channels // num_head_channels
1092
- self.use_checkpoint = use_checkpoint
1093
-
1094
- self.temporal_length = temporal_length
1095
- self.image_length = image_length
1096
- self.use_relative_position = use_relative_position
1097
- self.img_video_joint_train = img_video_joint_train
1098
- self.attn_norm_type = attn_norm_type
1099
- assert self.attn_norm_type in ["group", "no_norm"]
1100
- self.use_tempoal_causal_attn = use_tempoal_causal_attn
1101
-
1102
- if self.attn_norm_type == "group":
1103
- self.norm_s = normalization(channels)
1104
- self.norm_t = normalization(channels)
1105
-
1106
- self.qkv_s = conv_nd(1, channels, channels * 3, 1)
1107
- self.qkv_t = conv_nd(1, channels, channels * 3, 1)
1108
-
1109
- if self.img_video_joint_train:
1110
- mask = th.ones(
1111
- [1, temporal_length + image_length, temporal_length + image_length]
1112
- )
1113
- mask[:, temporal_length:, :] = 0
1114
- mask[:, :, temporal_length:] = 0
1115
- self.register_buffer("mask", mask)
1116
- else:
1117
- self.mask = None
1118
-
1119
- if use_new_attention_order:
1120
- # split qkv before split heads
1121
- self.attention_s = QKVAttention(self.num_heads)
1122
- self.attention_t = QKVAttention(self.num_heads)
1123
- else:
1124
- # split heads before split qkv
1125
- self.attention_s = QKVAttentionLegacy(self.num_heads)
1126
- self.attention_t = QKVAttentionLegacy(self.num_heads)
1127
-
1128
- if use_relative_position:
1129
- self.relative_position_k = RelativePosition(
1130
- num_units=channels // self.num_heads,
1131
- max_relative_position=temporal_length,
1132
- )
1133
- self.relative_position_v = RelativePosition(
1134
- num_units=channels // self.num_heads,
1135
- max_relative_position=temporal_length,
1136
- )
1137
-
1138
- self.proj_out_s = zero_module(
1139
- conv_nd(1, channels, channels, 1)
1140
- ) # conv_dim, in_channels, out_channels, kernel_size
1141
- self.proj_out_t = zero_module(
1142
- conv_nd(1, channels, channels, 1)
1143
- ) # conv_dim, in_channels, out_channels, kernel_size
1144
-
1145
- def forward(self, x, mask=None):
1146
- b, c, t, h, w = x.shape
1147
-
1148
- # spatial
1149
- out = rearrange(x, "b c t h w -> (b t) c (h w)")
1150
- if self.attn_norm_type == "no_norm":
1151
- qkv = self.qkv_s(out)
1152
- else:
1153
- qkv = self.qkv_s(self.norm_s(out))
1154
- out = self.attention_s(qkv)
1155
- out = self.proj_out_s(out)
1156
- out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
1157
- x += out
1158
-
1159
- # temporal
1160
- out = rearrange(x, "b c t h w -> (b h w) c t")
1161
- if self.attn_norm_type == "no_norm":
1162
- qkv = self.qkv_t(out)
1163
- else:
1164
- qkv = self.qkv_t(self.norm_t(out))
1165
-
1166
- # relative positional embedding
1167
- if self.use_relative_position:
1168
- len_q = qkv.size()[-1]
1169
- len_k, len_v = len_q, len_q
1170
- k_rp = self.relative_position_k(len_q, len_k)
1171
- v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
1172
- out = self.attention_t(
1173
- qkv,
1174
- rp=(k_rp, v_rp),
1175
- mask=self.mask,
1176
- use_tempoal_causal_attn=self.use_tempoal_causal_attn,
1177
- )
1178
- else:
1179
- out = self.attention_t(
1180
- qkv,
1181
- rp=None,
1182
- mask=self.mask,
1183
- use_tempoal_causal_attn=self.use_tempoal_causal_attn,
1184
- )
1185
-
1186
- out = self.proj_out_t(out)
1187
- out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
1188
-
1189
- return x + out
1190
-
1191
-
1192
- # ---------------------------------------------------------------------------------------------------------------
1193
-
1194
-
1195
- class QKVAttentionLegacy(nn.Module):
1196
- """
1197
- A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
1198
- """
1199
-
1200
- def __init__(self, n_heads):
1201
- super().__init__()
1202
- self.n_heads = n_heads
1203
-
1204
- def forward(self, qkv, rp=None, mask=None):
1205
- """
1206
- Apply QKV attention.
1207
-
1208
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
1209
- :return: an [N x (H * C) x T] tensor after attention.
1210
- """
1211
- if rp is not None or mask is not None:
1212
- raise NotImplementedError
1213
- bs, width, length = qkv.shape
1214
- assert width % (3 * self.n_heads) == 0
1215
- ch = width // (3 * self.n_heads)
1216
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
1217
- scale = 1 / math.sqrt(math.sqrt(ch))
1218
- weight = th.einsum(
1219
- "bct,bcs->bts", q * scale, k * scale
1220
- ) # More stable with f16 than dividing afterwards
1221
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
1222
- a = th.einsum("bts,bcs->bct", weight, v)
1223
- return a.reshape(bs, -1, length)
1224
-
1225
- @staticmethod
1226
- def count_flops(model, _x, y):
1227
- return count_flops_attn(model, _x, y)
1228
-
1229
-
1230
- # ---------------------------------------------------------------------------------------------------------------
1231
-
1232
-
1233
- class QKVAttention(nn.Module):
1234
- """
1235
- A module which performs QKV attention and splits in a different order.
1236
- """
1237
-
1238
- def __init__(self, n_heads):
1239
- super().__init__()
1240
- self.n_heads = n_heads
1241
-
1242
- def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
1243
- """
1244
- Apply QKV attention.
1245
-
1246
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
1247
- :return: an [N x (H * C) x T] tensor after attention.
1248
- """
1249
- bs, width, length = qkv.shape
1250
- assert width % (3 * self.n_heads) == 0
1251
- ch = width // (3 * self.n_heads)
1252
- # print('qkv', qkv.size())
1253
- qkv=qkv.contiguous()
1254
- q, k, v = qkv.chunk(3, dim=1)
1255
- scale = 1 / math.sqrt(math.sqrt(ch))
1256
- # print('bs, self.n_heads, ch, length', bs, self.n_heads, ch, length)
1257
-
1258
- weight = th.einsum(
1259
- "bct,bcs->bts",
1260
- (q * scale).view(bs * self.n_heads, ch, length),
1261
- (k * scale).view(bs * self.n_heads, ch, length),
1262
- ) # More stable with f16 than dividing afterwards
1263
- # weight:[b,t,s] b=bs*n_heads*T
1264
-
1265
- if rp is not None:
1266
- k_rp, v_rp = rp # [length, length, head_dim] [8, 8, 48]
1267
- weight2 = th.einsum(
1268
- "bct,tsc->bst", (q * scale).view(bs * self.n_heads, ch, length), k_rp
1269
- )
1270
- weight += weight2
1271
-
1272
- if use_tempoal_causal_attn:
1273
- # weight = torch.tril(weight)
1274
- assert mask is None, f"Not implemented for merging two masks!"
1275
- mask = torch.tril(torch.ones(weight.shape))
1276
- else:
1277
- if mask is not None: # only keep upper-left matrix
1278
- # process mask
1279
- c, t, _ = weight.shape
1280
-
1281
- if mask.shape[-1] > t:
1282
- mask = mask[:, :t, :t]
1283
- elif mask.shape[-1] < t: # pad ones
1284
- mask_ = th.zeros([c, t, t]).to(mask.device)
1285
- t_ = mask.shape[-1]
1286
- mask_[:, :t_, :t_] = mask
1287
- mask = mask_
1288
- else:
1289
- assert (
1290
- weight.shape[-1] == mask.shape[-1]
1291
- ), f"weight={weight.shape}, mask={mask.shape}"
1292
-
1293
- if mask is not None:
1294
- INF = -1e8 # float('-inf')
1295
- weight = weight.float().masked_fill(mask == 0, INF)
1296
-
1297
- weight = F.softmax(weight.float(), dim=-1).type(
1298
- weight.dtype
1299
- ) # [256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
1300
- # weight = F.softmax(weight, dim=-1)#[256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
1301
- a = th.einsum(
1302
- "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
1303
- ) # [256, 48, 8] [b, head_dim, t]
1304
-
1305
- if rp is not None:
1306
- a2 = th.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2) # btc->bct
1307
- a += a2
1308
-
1309
- return a.reshape(bs, -1, length)
1310
-
1311
-
1312
- # ---------------------------------------------------------------------------------------------------------------
1313
-
1314
- # ---------------------------------------------------------------------------------------------------------------
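A quick illustration of the relative-position indexing used by the RelativePosition class in the deleted file above: pairwise frame distances are clamped to [-max_relative_position, max_relative_position] and shifted to be non-negative so they can index the embedding table. Toy sizes; this snippet is an illustration, not part of the repository.

import torch

length_q, length_k, max_relative_position = 4, 4, 2
range_q = torch.arange(length_q)
range_k = torch.arange(length_k)
distance = range_k[None, :] - range_q[:, None]  # (length_q, length_k) pairwise offsets
# Clamp to [-max, max] and shift to [0, 2*max] so the result can index an embedding table.
index = torch.clamp(distance, -max_relative_position, max_relative_position) + max_relative_position
print(index)
# tensor([[2, 3, 4, 4],
#         [1, 2, 3, 4],
#         [0, 1, 2, 3],
#         [0, 0, 1, 2]])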
base_encoder.py DELETED
@@ -1,68 +0,0 @@
- from abc import ABC, abstractmethod
-
- import torch
- import torch.nn as nn
-
-
- class BaseVisionTower(nn.Module):
-     def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
-         super().__init__()
-
-         self.is_loaded = False
-
-         self.vision_tower_name = vision_tower_name
-         self.delay_load = delay_load
-
-     @abstractmethod
-     def load_model(self, device_map=None):
-         raise NotImplementedError("Subclasses must implement load_model")
-
-     @abstractmethod
-     def _forward(self, images):
-         raise NotImplementedError("Subclasses must implement forward")
-
-     def forward(self, images):
-         if type(images) is list:
-             image_features = [self._forward(image.unsqueeze(0)) for image in images]
-         else:
-             image_features = self._forward(images)
-
-         return image_features
-
-     @property
-     def dummy_feature(self):
-         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-     @property
-     def dtype(self):
-         # Dynamically infer the dtype from the first parameter, if not explicitly specified
-         if hasattr(self.vision_tower, "dtype"):
-             return self.vision_tower.dtype
-         else:
-             params = list(self.vision_tower.parameters())
-             return (
-                 params[0].dtype if len(params) > 0 else torch.float32
-             )  # Default to torch.float32 if no parameters
-
-     @property
-     def device(self):
-         # Dynamically infer the device from the first parameter, if not explicitly specified
-         if hasattr(self.vision_tower, "device"):
-             return self.vision_tower.device
-         else:
-             params = list(self.vision_tower.parameters())
-             return (
-                 params[0].device if len(params) > 0 else torch.device("cpu")
-             )  # Default to CPU if no parameters
-     @property
-     def config(self):
-         if self.is_loaded:
-             return self.vision_tower.config
-         else:
-             return self.cfg_only
-     @property
-     def hidden_size(self):
-         try:
-             return self.config.hidden_size
-         except:
-             return self._hidden_size
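BaseVisionTower leaves only load_model and _forward abstract; the dtype/device/config/hidden_size plumbing is inherited. A hypothetical minimal subclass, assuming the BaseVisionTower above is importable and using the SiglipVisionModel class from transformers (all names below are illustrative, not from this repository):

import torch
from transformers import SiglipVisionModel


class ToySigLipTower(BaseVisionTower):  # BaseVisionTower as defined in the deleted file above
    def load_model(self, device_map=None):
        # Populate self.vision_tower so the inherited dtype/device/config properties work.
        self.vision_tower = SiglipVisionModel.from_pretrained(
            self.vision_tower_name, device_map=device_map
        )
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def _forward(self, images):
        with torch.no_grad():
            return self.vision_tower(images).last_hidden_state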
builder.py DELETED
@@ -1,17 +0,0 @@
- import os
- from .siglip_encoder import SigLipVisionTower
-
-
- def build_vision_tower(vision_tower_cfg, **kwargs):
-
-     vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
-     is_absolute_path_exists = os.path.exists(vision_tower)
-     use_s2 = getattr(vision_tower_cfg, "s2", False)
-
-     #print(getattr(vision_tower_cfg, "vision_tower", None))
-     return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
-     if getattr(vision_tower_cfg, "vision_tower", None) and "siglip" in getattr(vision_tower_cfg, "vision_tower", None).lower():
-         #print('*************\n')
-         return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
-
-     raise ValueError(f"Unknown vision tower: {vision_tower}")
config.json CHANGED
@@ -4,7 +4,7 @@
   ],
   "auto_map": {
     "AutoConfig": "llava_qwen.LlavaQwenConfig",
-    "AutoModel": "llava_qwen.LlavaQwenForCausalLM"
+    "AutoModelForCausalLM": "llava_qwen.LlavaQwenForCausalLM"
   },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
@@ -202,5 +202,7 @@
   "use_pos_skipping": false,
   "use_sliding_window": false,
   "vision_tower_pretrained": null,
-  "vocab_size": 152064
+  "vocab_size": 152064,
+  "enable_chunk_prefill": false,
+  "prefill_config": {}
   }
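With "AutoModelForCausalLM" now registered in auto_map, the custom LlavaQwenForCausalLM implementation can be resolved through the standard Auto classes when remote code is trusted. A minimal loading sketch; the repository id is a placeholder, since the repo name is not captured on this page:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "BAAI/<this-repository>"  # placeholder, not taken from this page
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)            # -> llava_qwen.LlavaQwenConfig
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)   # -> llava_qwen.LlavaQwenForCausalLM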
configuration_qwen2.py DELETED
@@ -1,169 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Qwen2 model configuration"""
16
-
17
- from transformers.configuration_utils import PretrainedConfig
18
- from transformers.utils import logging
19
-
20
-
21
- logger = logging.get_logger(__name__)
22
-
23
- QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
- "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json",
25
- }
26
-
27
-
28
- class Qwen2Config(PretrainedConfig):
29
- r"""
30
- This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
31
- Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
- with the defaults will yield a similar configuration to that of
33
- Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
34
-
35
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
- documentation from [`PretrainedConfig`] for more information.
37
-
38
-
39
- Args:
40
- vocab_size (`int`, *optional*, defaults to 151936):
41
- Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
42
- `inputs_ids` passed when calling [`Qwen2Model`]
43
- hidden_size (`int`, *optional*, defaults to 4096):
44
- Dimension of the hidden representations.
45
- intermediate_size (`int`, *optional*, defaults to 22016):
46
- Dimension of the MLP representations.
47
- num_hidden_layers (`int`, *optional*, defaults to 32):
48
- Number of hidden layers in the Transformer encoder.
49
- num_attention_heads (`int`, *optional*, defaults to 32):
50
- Number of attention heads for each attention layer in the Transformer encoder.
51
- num_key_value_heads (`int`, *optional*, defaults to 32):
52
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
- by meanpooling all the original heads within that group. For more details checkout [this
57
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
58
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
- The non-linear activation function (function or string) in the decoder.
60
- max_position_embeddings (`int`, *optional*, defaults to 32768):
61
- The maximum sequence length that this model might ever be used with.
62
- initializer_range (`float`, *optional*, defaults to 0.02):
63
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65
- The epsilon used by the rms normalization layers.
66
- use_cache (`bool`, *optional*, defaults to `True`):
67
- Whether or not the model should return the last key/values attentions (not used by all models). Only
68
- relevant if `config.is_decoder=True`.
69
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70
- Whether the model's input and output word embeddings should be tied.
71
- rope_theta (`float`, *optional*, defaults to 10000.0):
72
- The base period of the RoPE embeddings.
73
- use_sliding_window (`bool`, *optional*, defaults to `False`):
74
- Whether to use sliding window attention.
75
- sliding_window (`int`, *optional*, defaults to 4096):
76
- Sliding window attention (SWA) window size. If not specified, will default to `4096`.
77
- max_window_layers (`int`, *optional*, defaults to 28):
78
- The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
79
- attention_dropout (`float`, *optional*, defaults to 0.0):
80
- The dropout ratio for the attention probabilities.
81
-
82
- ```python
83
- >>> from transformers import Qwen2Model, Qwen2Config
84
-
85
- >>> # Initializing a Qwen2 style configuration
86
- >>> configuration = Qwen2Config()
87
-
88
- >>> # Initializing a model from the Qwen2-7B style configuration
89
- >>> model = Qwen2Model(configuration)
90
-
91
- >>> # Accessing the model configuration
92
- >>> configuration = model.config
93
- ```"""
94
-
95
- model_type = "qwen2"
96
- keys_to_ignore_at_inference = ["past_key_values"]
97
-
98
- def __init__(
99
- self,
100
- vocab_size=151936,
101
- hidden_size=4096,
102
- intermediate_size=22016,
103
- num_hidden_layers=32,
104
- num_attention_heads=32,
105
- num_key_value_heads=32,
106
- hidden_act="silu",
107
- max_position_embeddings=32768,
108
- initializer_range=0.02,
109
- rms_norm_eps=1e-6,
110
- use_cache=True,
111
- tie_word_embeddings=False,
112
- rope_theta=10000.0,
113
- use_sliding_window=False,
114
- sliding_window=4096,
115
- rope_scaling=None,
116
- max_window_layers=28,
117
- attention_dropout=0.0,
118
- beacon_window=1024,
119
- beacon_stride=1024,
120
- beacon_attn="full-coverage",
121
- beacon_ratio=[2,4,8,16,32],
122
- beacon_ratio_mix="step-random",
123
- beacon_param=[],
124
- beacon_embed_init="eos",
125
- beacon_sink_size=0,
126
- beacon_attend_prev=True,
127
- beacon_pos="interleave",
128
- beacon_parallel_window=1,
129
- **kwargs,
130
- ):
131
- self.vocab_size = vocab_size
132
- self.max_position_embeddings = max_position_embeddings
133
- self.hidden_size = hidden_size
134
- self.intermediate_size = intermediate_size
135
- self.num_hidden_layers = num_hidden_layers
136
- self.num_attention_heads = num_attention_heads
137
- self.use_sliding_window = use_sliding_window
138
- self.sliding_window = sliding_window
139
- self.max_window_layers = max_window_layers
140
- self.rope_scaling = rope_scaling
141
-
142
- # for backward compatibility
143
- if num_key_value_heads is None:
144
- num_key_value_heads = num_attention_heads
145
-
146
- self.num_key_value_heads = num_key_value_heads
147
- self.hidden_act = hidden_act
148
- self.initializer_range = initializer_range
149
- self.rms_norm_eps = rms_norm_eps
150
- self.use_cache = use_cache
151
- self.rope_theta = rope_theta
152
- self.attention_dropout = attention_dropout
153
-
154
- self.beacon_window = beacon_window
155
- self.beacon_stride = beacon_stride
156
- self.beacon_attn = beacon_attn
157
- self.beacon_ratio = beacon_ratio
158
- self.beacon_ratio_mix = beacon_ratio_mix
159
- self.beacon_param = beacon_param
160
- self.beacon_embed_init = beacon_embed_init
161
- self.beacon_sink_size = beacon_sink_size
162
- self.beacon_attend_prev = beacon_attend_prev
163
- self.beacon_pos = beacon_pos
164
- self.beacon_parallel_window = beacon_parallel_window
165
-
166
- super().__init__(
167
- tie_word_embeddings=tie_word_embeddings,
168
- **kwargs,
169
- )
llava_arch.py CHANGED
@@ -1,17 +1,3 @@
1
- # Copyright 2023 Haotian Liu
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
  from abc import ABC, abstractmethod
16
  import importlib.util
17
  import os.path as osp
@@ -26,25 +12,6 @@ import torch.nn.functional as F
26
  from .multimodal_encoder.builder import build_vision_tower
27
  from .multimodal_resampler.builder import build_vision_resampler
28
  from .multimodal_projector.builder import build_vision_projector
29
- # except ModuleNotFoundError:
30
- # spec = importlib.util.spec_from_file_location(
31
- # "builder",
32
- # osp.join(osp.dirname(__file__), "builder.py"),
33
- # )
34
- # builder = importlib.util.module_from_spec(spec)
35
- # spec.loader.exec_module(builder)
36
- # build_vision_tower = getattr(
37
- # builder,
38
- # "build_vision_tower",
39
- # )
40
- # build_vision_resampler = getattr(
41
- # builder,
42
- # "build_vision_resampler",
43
- # )
44
- # build_vision_projector = getattr(
45
- # builder,
46
- # "build_vision_projector",
47
- # )
48
 
49
  from transformers import AutoTokenizer
50
 
@@ -59,7 +26,6 @@ import torch.nn.functional as F
59
  import pdb
60
 
61
  class LlavaMetaModel:
62
-
63
  def __init__(self, config):
64
  super(LlavaMetaModel, self).__init__(config)
65
 
@@ -72,31 +38,13 @@ class LlavaMetaModel:
72
  if "unpad" in getattr(config, "mm_patch_merge_type", ""):
73
  self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
74
 
75
- # self.llm_tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
76
  self.hidden_size=config.hidden_size
77
- # print(config)
78
- # exit(0)
79
-
80
- # self.text_tokenizer = T5Tokenizer.from_pretrained('google-t5/t5-small')
81
- ##############################################################################
82
- # self.text_select_model = T5EncoderModel.from_pretrained('google-t5/t5-small')
83
-
84
- # self.text_gamma=0.75
85
-
86
- ###############################################################################
87
  self.text_mlp=nn.Sequential(
88
  nn.Linear(config.hidden_size,config.hidden_size),
89
  nn.GELU(),
90
  )
91
  self.sae=SiglipAE()
92
- #self.sae.load_state_dict(torch.load('/share/LXRlxr0_0/code/videoxl2/videoxl2/longva/longva/model/encoder.pth'),strict=False)
93
-
94
- ###############################################################################
95
- # self.vision_select=nn.Parameter(
96
- # torch.randn((4, self.config.hidden_size), dtype=self.dtype)
97
- # )
98
- ##############################################################################
99
-
100
  def get_vision_tower(self):
101
  vision_tower = getattr(self, "vision_tower", None)
102
  if type(vision_tower) is list:
@@ -147,22 +95,6 @@ class LlavaMetaModel:
147
 
148
  self.sae=SiglipAE()
149
  self.sae.load_state_dict(torch.load('/share/LXRlxr0_0/code/videoxl2/videoxl2/longva/longva/model/encoder.pth'),strict=False)
150
- ##############################################################################
151
- # self.vision_select=nn.Parameter(
152
- # torch.randn((30, self.config.hidden_size), dtype=self.dtype)
153
- # )
154
-
155
- # #self.text_tokenizer = T5Tokenizer.from_pretrained('google-t5/t5-small')
156
- # self.text_select_model = T5EncoderModel.from_pretrained('google-t5/t5-small')
157
-
158
- # self.text_mlp=nn.Sequential(
159
- # nn.Linear(512,self.config.hidden_size),
160
- # nn.GELU(),
161
- # # nn.Linear(config.hidden_size,config.hidden_size),
162
- # # nn.GELU(),
163
- # )
164
- ##############################################################################
165
-
166
 
167
  if getattr(self, "mm_projector", None) is None:
168
  self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
@@ -185,15 +117,7 @@ class LlavaMetaModel:
185
  rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
186
  incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
187
  rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
188
-
189
-
190
- # self.vision_select.data = mm_projector_weights["model.vision_select"]
191
-
192
- # self.text_mlp.load_state_dict(get_w(mm_projector_weights, "text_mlp"))
193
-
194
- # self.text_select_model.load_state_dict(get_w(mm_projector_weights, "text_select_model"),strict=False)
195
- #self.vision_tower.load_state_dict(get_w(mm_projector_weights, "vision_tower"),strict=False)
196
-
197
  def unpad_image(tensor, original_size):
198
  """
199
  Unpads a PyTorch tensor of a padded and resized image.
@@ -283,25 +207,30 @@ class LlavaMetaForCausalLM(ABC):
283
  return torch.repeat_interleave(image_features, repeats=4, dim=0)
284
 
285
  def add_video(self, video_features):
286
- if video_features.size(0)<4:
 
 
 
287
  last_feature = video_features[-1:]
288
-
289
- repeated_features = last_feature.repeat(4 - video_features.size(0), 1,1,1)
 
 
290
  expanded_x = torch.cat([video_features, repeated_features], dim=0)
291
  return expanded_x
292
-
293
- repeat_counts = torch.ones(video_features.size(0), dtype=torch.long, device=video_features.device)
294
 
295
- sum_counts=torch.sum(repeat_counts)
296
- if sum_counts % 4!=0:
297
- padding_size = 4 - (sum_counts % 4)
298
- random_indices = torch.randperm(repeat_counts.size(0))[:padding_size].to(video_features.device)
299
- repeat_counts[random_indices] += 1
300
-
301
- expanded_x = torch.repeat_interleave(video_features, repeat_counts, dim=0)
 
 
302
 
303
- return expanded_x
304
-
305
  def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
306
  if self.config.enable_chunk_prefill:
307
  chunk_size_for_vision_tower = self.config.prefill_config['chunk_size_for_vision_tower']
@@ -351,28 +280,27 @@ class LlavaMetaForCausalLM(ABC):
351
  torch.cuda.empty_cache()
352
 
353
  chunk_size = chunk_size_for_vision_tower
354
- print(f'chunk_size: {chunk_size}')
355
  all_feat_list = []
356
  for idx, feat in enumerate(per_videos_or_images_features):
357
  for i in range(0, feat.shape[0], chunk_size):
358
- batched_feat = feat[i:i+chunk_size]
359
- batched_feat=self.interpolate(batched_feat) # torch.Size([187, 1152, 24, 24])
360
  if idx in video_idx_in_batch:
361
- batched_feat = self.add_video(batched_feat) # torch.Size([188, 1152, 24, 24])
362
  else:
363
  batched_feat = self.add_image(batched_feat)
364
 
365
  bc,ch,h,w = batched_feat.shape
366
  batched_feat = batched_feat.view(bc//4,ch,4,h,w)
367
 
368
- batched_feat=self.get_model().sae(batched_feat).squeeze(2)
369
  batched_feat = batched_feat.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
 
370
  batched_feat = self.get_model().mm_projector(batched_feat)
371
-
372
-
373
  batched_feat = self.get_2dPool(batched_feat)
374
  all_feat_list.append(batched_feat)
375
-
376
  feat = torch.cat(all_feat_list, dim=0)
377
  # peak_memory_allocated = torch.cuda.max_memory_allocated()
378
  # print(f"sae 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
@@ -380,10 +308,8 @@ class LlavaMetaForCausalLM(ABC):
380
  del per_videos_or_images_features
381
  del all_feat_list
382
  torch.cuda.empty_cache()
383
-
384
  all_videos_or_images_features.append(feat)
385
  return all_videos_or_images_features
386
-
387
 
388
  def interpolate(self,image_features):
389
  b, num_tokens, dim = image_features.shape
@@ -673,7 +599,7 @@ class LlavaMetaForCausalLM(ABC):
673
 
674
  # Truncate sequences to max length as image embeddings can make the sequence longer
675
  tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
676
- # NOTE: qmh 注释
677
  # new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
678
  # new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
679
 
@@ -771,5 +697,4 @@ class LlavaMetaForCausalLM(ABC):
771
  for p in self.get_input_embeddings().parameters():
772
  p.requires_grad = False
773
  for p in self.get_output_embeddings().parameters():
774
- p.requires_grad = False
775
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from abc import ABC, abstractmethod
2
  import importlib.util
3
  import os.path as osp
 
12
  from .multimodal_encoder.builder import build_vision_tower
13
  from .multimodal_resampler.builder import build_vision_resampler
14
  from .multimodal_projector.builder import build_vision_projector
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  from transformers import AutoTokenizer
17
 
 
26
  import pdb
27
 
28
  class LlavaMetaModel:
 
29
  def __init__(self, config):
30
  super(LlavaMetaModel, self).__init__(config)
31
 
 
38
  if "unpad" in getattr(config, "mm_patch_merge_type", ""):
39
  self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
40
 
 
41
  self.hidden_size=config.hidden_size
 
 
 
 
 
 
 
 
 
 
42
  self.text_mlp=nn.Sequential(
43
  nn.Linear(config.hidden_size,config.hidden_size),
44
  nn.GELU(),
45
  )
46
  self.sae=SiglipAE()
47
+
 
 
 
 
 
 
 
48
  def get_vision_tower(self):
49
  vision_tower = getattr(self, "vision_tower", None)
50
  if type(vision_tower) is list:
 
95
 
96
  self.sae=SiglipAE()
97
  self.sae.load_state_dict(torch.load('/share/LXRlxr0_0/code/videoxl2/videoxl2/longva/longva/model/encoder.pth'),strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  if getattr(self, "mm_projector", None) is None:
100
  self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
 
117
  rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
118
  incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
119
  rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
120
+
 
 
 
 
 
 
 
 
121
  def unpad_image(tensor, original_size):
122
  """
123
  Unpads a PyTorch tensor of a padded and resized image.
 
207
  return torch.repeat_interleave(image_features, repeats=4, dim=0)
208
 
209
  def add_video(self, video_features):
210
+ # Current batch size
211
+ current_batch_size = video_features.size(0)
212
+ # Handle cases where the batch size is less than 4
213
+ if current_batch_size < 4:
214
  last_feature = video_features[-1:]
215
+ # Calculate how many times the last feature needs to be repeated
216
+ num_repeats = 4 - current_batch_size
217
+ repeated_features = last_feature.repeat(num_repeats, 1, 1, 1)
218
+ # Concatenate original features with repeated last feature
219
  expanded_x = torch.cat([video_features, repeated_features], dim=0)
220
  return expanded_x
 
 
221
 
222
+ # Handle cases where the batch size is 4 or greater, but not a multiple of 4
223
+ if current_batch_size % 4 != 0:
224
+ last_feature = video_features[-1:]
225
+ # Calculate how many features are needed to reach the next multiple of 4
226
+ padding_size = 4 - (current_batch_size % 4)
227
+ repeated_features = last_feature.repeat(padding_size, 1, 1, 1)
228
+ # Concatenate original features with repeated last feature
229
+ expanded_x = torch.cat([video_features, repeated_features], dim=0)
230
+ return expanded_x
231
 
232
+ # If the batch size is already a multiple of 4, return as is
233
+ return video_features
234
  def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
235
  if self.config.enable_chunk_prefill:
236
  chunk_size_for_vision_tower = self.config.prefill_config['chunk_size_for_vision_tower']
 
280
  torch.cuda.empty_cache()
281
 
282
  chunk_size = chunk_size_for_vision_tower
283
+ # print(f'chunk_size: {chunk_size}')
284
  all_feat_list = []
285
  for idx, feat in enumerate(per_videos_or_images_features):
286
  for i in range(0, feat.shape[0], chunk_size):
287
+ batched_feat = feat[i:i+chunk_size] # chunk_size = 48, batched_feat.shape=[48, 729, 1152]
288
+ batched_feat=self.interpolate(batched_feat) # after interpolation: batched_feat.shape=[48, 1152, 24, 24]
289
  if idx in video_idx_in_batch:
290
+ batched_feat = self.add_video(batched_feat) # pad the first dimension up to a multiple of 4
291
  else:
292
  batched_feat = self.add_image(batched_feat)
293
 
294
  bc,ch,h,w = batched_feat.shape
295
  batched_feat = batched_feat.view(bc//4,ch,4,h,w)
296
 
297
+ batched_feat = self.get_model().sae(batched_feat).squeeze(2)
298
  batched_feat = batched_feat.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
299
+
300
  batched_feat = self.get_model().mm_projector(batched_feat)
 
 
301
  batched_feat = self.get_2dPool(batched_feat)
302
  all_feat_list.append(batched_feat)
303
+
304
  feat = torch.cat(all_feat_list, dim=0)
305
  # peak_memory_allocated = torch.cuda.max_memory_allocated()
306
  # print(f"sae 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
 
308
  del per_videos_or_images_features
309
  del all_feat_list
310
  torch.cuda.empty_cache()
 
311
  all_videos_or_images_features.append(feat)
312
  return all_videos_or_images_features
 
313
 
314
  def interpolate(self,image_features):
315
  b, num_tokens, dim = image_features.shape
 
599
 
600
  # Truncate sequences to max length as image embeddings can make the sequence longer
601
  tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
602
+ # NOTE: qmh
603
  # new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
604
  # new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
605
 
 
697
  for p in self.get_input_embeddings().parameters():
698
  p.requires_grad = False
699
  for p in self.get_output_embeddings().parameters():
700
+ p.requires_grad = False
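Note on the add_video rewrite shown above: the old version padded by repeat_interleave at randomly chosen indices, while the new version repeats the last frame until the frame count is a multiple of 4, which is the group size the SAE consumes. A standalone sketch of that padding rule (tensor shapes are illustrative, not tied to the repository):

import torch

def pad_frames_to_multiple_of_4(video_features: torch.Tensor) -> torch.Tensor:
    # (-n) % 4 covers both branches above: fewer than 4 frames pad up to 4,
    # otherwise pad up to the next multiple of 4; exact multiples pass through.
    pad = (-video_features.size(0)) % 4
    if pad == 0:
        return video_features
    repeated = video_features[-1:].repeat(pad, 1, 1, 1)
    return torch.cat([video_features, repeated], dim=0)

x = torch.randn(6, 1152, 24, 24)                  # 6 frames
assert pad_frames_to_multiple_of_4(x).shape[0] == 8

The new rule is deterministic, whereas the old repeat_interleave duplicated randomly chosen frames.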
 
llava_qwen.py CHANGED
@@ -11,8 +11,6 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
-
15
-
16
  from typing import List, Optional, Tuple, Union, Dict
17
  import torch
18
  import torch.nn as nn
@@ -21,9 +19,9 @@ import transformers
21
  from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
22
  from transformers.modeling_outputs import CausalLMOutputWithPast
23
  from transformers.generation.utils import GenerateOutput
24
- from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
25
- # from longva.longva.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
26
  from .modeling_qwen2 import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
 
27
  import pdb
28
  import time
29
  import random
@@ -35,7 +33,671 @@ import PIL
35
  from decord import VideoReader, cpu
36
  from .conversation import conv_templates, SeparatorStyle
37
  from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_TOKEN
38
- from .mm_utils import tokenizer_image_token, load_video
39
 
40
 
41
  class LlavaQwenConfig(Qwen2Config):
@@ -518,7 +1180,6 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
518
  )
519
 
520
  if inputs_embeds is None:
521
- pdb.set_trace()
522
  (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes, time_embedding)
523
 
524
  if self.config.enable_chunk_prefill:
@@ -600,8 +1261,6 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
600
  **kwargs,
601
  ) -> Union[GenerateOutput, torch.LongTensor]:
602
 
603
-
604
-
605
  position_ids = kwargs.pop("position_ids", None)
606
  attention_mask = kwargs.pop("attention_mask", None)
607
 
@@ -664,9 +1323,14 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
664
  prompt = conv.get_prompt()
665
  input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.model.device)
666
 
 
 
 
 
 
667
  # prepare video input
668
  frames, timestamps = load_video(video_path, max_num_frames, fps=sample_fps, max_fps=max_sample_fps)
669
- print(f'video has loaded, extratc {len(frames)} frames.')
670
 
671
  time_stamps=[]
672
  token_frames_sum=(len(timestamps)+3)//4
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
 
14
  from typing import List, Optional, Tuple, Union, Dict
15
  import torch
16
  import torch.nn as nn
 
19
  from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
20
  from transformers.modeling_outputs import CausalLMOutputWithPast
21
  from transformers.generation.utils import GenerateOutput
22
+ # from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
 
23
  from .modeling_qwen2 import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
24
+ # from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
25
  import pdb
26
  import time
27
  import random
 
33
  from decord import VideoReader, cpu
34
  from .conversation import conv_templates, SeparatorStyle
35
  from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_TOKEN
36
+ from .mm_utils import tokenizer_image_token, load_video, KeywordsStoppingCriteria, get_anyres_image_grid_shape
37
+ import math
38
+ import re
39
+ from .vision_tower_builder import build_vision_tower
40
+ from .vision_resampler_builder import build_vision_resampler
41
+ from .vision_projector_builder import build_vision_projector
42
+ from .utils import rank0_print
43
+ from .sae import SiglipAE
44
+ import numpy as np
45
+ import pdb
46
+ from abc import ABC, abstractmethod
47
+
48
+ class LlavaMetaModel:
49
+ def __init__(self, config):
50
+ super(LlavaMetaModel, self).__init__(config)
51
+
52
+ if hasattr(config, "mm_vision_tower"):
53
+ delay_load = getattr(config, "delay_load", False)
54
+ self.vision_tower = build_vision_tower(config, delay_load=delay_load)
55
+ self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
56
+ self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
57
+
58
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
59
+ self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
60
+
61
+ self.hidden_size=config.hidden_size
62
+ self.text_mlp=nn.Sequential(
63
+ nn.Linear(config.hidden_size,config.hidden_size),
64
+ nn.GELU(),
65
+ )
66
+ self.sae=SiglipAE()
67
+
68
+ def get_vision_tower(self):
69
+ vision_tower = getattr(self, "vision_tower", None)
70
+ if type(vision_tower) is list:
71
+ vision_tower = vision_tower[0]
72
+ return vision_tower
73
+
74
+ def initialize_vision_modules(self, model_args, fsdp=None):
75
+ vision_tower = model_args.vision_tower
76
+ mm_vision_select_layer = model_args.mm_vision_select_layer
77
+ mm_vision_select_feature = model_args.mm_vision_select_feature
78
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
79
+ mm_patch_merge_type = model_args.mm_patch_merge_type
80
+
81
+ self.config.mm_vision_tower = vision_tower
82
+ self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
83
+
84
+ if self.get_vision_tower() is None:
85
+ vision_tower = build_vision_tower(model_args)
86
+ vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
87
+ for k, v in vision_resampler.config.items():
88
+ setattr(self.config, k, v)
89
+
90
+ if fsdp is not None and len(fsdp) > 0:
91
+ self.vision_tower = [vision_tower]
92
+ self.vision_resampler = [vision_resampler]
93
+ else:
94
+ self.vision_tower = vision_tower
95
+ self.vision_resampler = vision_resampler
96
+ else:
97
+ if fsdp is not None and len(fsdp) > 0:
98
+ vision_resampler = self.vision_resampler[0]
99
+ vision_tower = self.vision_tower[0]
100
+ else:
101
+ vision_resampler = self.vision_resampler
102
+ vision_tower = self.vision_tower
103
+ vision_tower.load_model()
104
+
105
+ # In case it is frozen by LoRA
106
+ for p in self.vision_resampler.parameters():
107
+ p.requires_grad = True
108
+
109
+ self.config.use_mm_proj = True
110
+ self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
111
+ self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
112
+ self.config.mm_vision_select_layer = mm_vision_select_layer
113
+ self.config.mm_vision_select_feature = mm_vision_select_feature
114
+ self.config.mm_patch_merge_type = mm_patch_merge_type
115
+
116
+ self.sae=SiglipAE()
117
+ self.sae.load_state_dict(torch.load('/share/LXRlxr0_0/code/videoxl2/videoxl2/longva/longva/model/encoder.pth'),strict=False)
118
+
119
+ if getattr(self, "mm_projector", None) is None:
120
+ self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
121
+
122
+ if "unpad" in mm_patch_merge_type:
123
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
124
+ self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
125
+ else:
126
+ # In case it is frozen by LoRA
127
+ for p in self.mm_projector.parameters():
128
+ p.requires_grad = True
129
+
130
+ if pretrain_mm_mlp_adapter is not None:
131
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
132
+
133
+ def get_w(weights, keyword):
134
+ return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
135
+
136
+ incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
137
+ rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
138
+ incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
139
+ rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
140
+
141
+ def unpad_image(tensor, original_size):
142
+ """
143
+ Unpads a PyTorch tensor of a padded and resized image.
144
+
145
+ Args:
146
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
147
+ original_size (tuple): The original size of the image (height, width).
148
+
149
+ Returns:
150
+ torch.Tensor: The unpadded image tensor.
151
+ """
152
+ original_width, original_height = original_size
153
+ current_height, current_width = tensor.shape[1:]
154
+
155
+ # Compute aspect ratios
156
+ original_aspect_ratio = original_width / original_height
157
+ current_aspect_ratio = current_width / current_height
158
+
159
+ # Determine padding size and direction
160
+ if original_aspect_ratio > current_aspect_ratio:
161
+ # Padding was added to the height
162
+ scale_factor = current_width / original_width
163
+ new_height = int(original_height * scale_factor)
164
+ padding = (current_height - new_height) // 2
165
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
166
+ else:
167
+ # Padding was added to the width
168
+ scale_factor = current_height / original_height
169
+ new_width = int(original_width * scale_factor)
170
+ padding = (current_width - new_width) // 2
171
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
172
+
173
+ return unpadded_tensor
174
+
175
+ class LlavaMetaForCausalLM(ABC):
176
+ @abstractmethod
177
+ def get_model(self):
178
+ pass
179
+
180
+ def get_vision_tower(self):
181
+ return self.get_model().get_vision_tower()
182
+
183
+ def get_2dPool(self, image_feature):
184
+ height = width = self.get_vision_tower().num_patches_per_side
185
+ num_frames, num_tokens, num_dim = image_feature.shape
186
+ image_feature = image_feature.view(num_frames, height, width, -1)
187
+ image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
188
+ # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
189
+ if self.config.mm_spatial_pool_mode == "average":
190
+ image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride)
191
+ elif self.config.mm_spatial_pool_mode == "max":
192
+ image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
193
+ else:
194
+ raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
195
+ image_feature = image_feature.permute(0, 2, 3, 1)
196
+ image_feature = image_feature.view(num_frames, -1, num_dim)
197
+ return image_feature
198
+
199
+ def encode_images(self, images):
200
+ image_features = self.get_model().get_vision_tower()(images)
201
+ #image_features = self.get_model().vision_resampler(image_features, images=images)
202
+ image_features = self.get_model().mm_projector(image_features)
203
+ image_features = self.get_model().vision_resampler(image_features, images=images)
204
+ return image_features
205
+
206
+ def add_image(self, image_features):
207
+ return torch.repeat_interleave(image_features, repeats=4, dim=0)
208
+
209
+ def add_video(self, video_features):
210
+ # Current batch size
211
+ current_batch_size = video_features.size(0)
212
+ # Handle cases where the batch size is less than 4
213
+ if current_batch_size < 4:
214
+ last_feature = video_features[-1:]
215
+ # Calculate how many times the last feature needs to be repeated
216
+ num_repeats = 4 - current_batch_size
217
+ repeated_features = last_feature.repeat(num_repeats, 1, 1, 1)
218
+ # Concatenate original features with repeated last feature
219
+ expanded_x = torch.cat([video_features, repeated_features], dim=0)
220
+ return expanded_x
221
+
222
+ # Handle cases where the batch size is 4 or greater, but not a multiple of 4
223
+ if current_batch_size % 4 != 0:
224
+ last_feature = video_features[-1:]
225
+ # Calculate how many features are needed to reach the next multiple of 4
226
+ padding_size = 4 - (current_batch_size % 4)
227
+ repeated_features = last_feature.repeat(padding_size, 1, 1, 1)
228
+ # Concatenate original features with repeated last feature
229
+ expanded_x = torch.cat([video_features, repeated_features], dim=0)
230
+ return expanded_x
231
+
232
+ # If the batch size is already a multiple of 4, return as is
233
+ return video_features
234
+ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
235
+ if self.config.enable_chunk_prefill:
236
+ chunk_size_for_vision_tower = self.config.prefill_config['chunk_size_for_vision_tower']
237
+ else:
238
+ chunk_size_for_vision_tower = 100000
239
+ # pdb.set_trace()
240
+ # Define the maximum batch size (1024 frames)
241
+ max_batch_size = chunk_size_for_vision_tower
242
+ # print(f'max_batch_size: {max_batch_size}')
243
+ num_frames = videos_or_images.shape[0]
244
+ # Initialize a list to store the features from each batch
245
+ videos_or_images_features = []
246
+
247
+ videos_or_images_features = torch.empty((num_frames, 729, 1152), device=self.get_model().device, dtype=self.get_model().dtype)
248
+
249
+ # Split videos_or_images into smaller batches if num_frames > max_batch_size
250
+ current_idx = 0
251
+ if num_frames > max_batch_size:
252
+ # Calculate the number of batches needed
253
+ num_batches = (num_frames + max_batch_size - 1) // max_batch_size
254
+ for i in range(num_batches):
255
+ start_idx = i * max_batch_size
256
+ end_idx = min((i + 1) * max_batch_size, num_frames)
257
+
258
+ # Process each batch separately
259
+ batch_videos_or_images = videos_or_images[start_idx:end_idx]
260
+ batch_features = self.get_model().get_vision_tower()(batch_videos_or_images)
261
+ # videos_or_images_features.append(batch_features)
262
+
263
+ videos_or_images_features[current_idx:current_idx + batch_features.shape[0]] = batch_features
264
+ # Update the current index for the next batch
265
+ current_idx += batch_features.shape[0]
266
+ # peak_memory_allocated = torch.cuda.max_memory_allocated()
267
+ # print(f"vision encoder 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
268
+
269
+ # Concatenate the features of all batches
270
+ # videos_or_images_features = torch.cat(videos_or_images_features, dim=0)
271
+ else:
272
+ videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
273
+
274
+ per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0)
275
+ all_videos_or_images_features = []
276
+
277
+ # peak_memory_allocated = torch.cuda.max_memory_allocated()
278
+ # print(f"vision encoder 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
279
+ del videos_or_images_features
280
+ torch.cuda.empty_cache()
281
+
282
+ chunk_size = chunk_size_for_vision_tower
283
+ # print(f'chunk_size: {chunk_size}')
284
+ all_feat_list = []
285
+ for idx, feat in enumerate(per_videos_or_images_features):
286
+ for i in range(0, feat.shape[0], chunk_size):
287
+ batched_feat = feat[i:i+chunk_size] # chunk_size = 48, batched_feat.shape=[48, 729, 1152]
288
+ batched_feat=self.interpolate(batched_feat) # after interpolation: batched_feat.shape=[48, 1152, 24, 24]
289
+ if idx in video_idx_in_batch:
290
+ batched_feat = self.add_video(batched_feat) # pad the first dimension up to a multiple of 4
291
+ else:
292
+ batched_feat = self.add_image(batched_feat)
293
+
294
+ bc,ch,h,w = batched_feat.shape
295
+ batched_feat = batched_feat.view(bc//4,ch,4,h,w)
296
+
297
+ batched_feat = self.get_model().sae(batched_feat).squeeze(2)
298
+ batched_feat = batched_feat.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
299
+
300
+ batched_feat = self.get_model().mm_projector(batched_feat)
301
+ batched_feat = self.get_2dPool(batched_feat)
302
+ all_feat_list.append(batched_feat)
303
+
304
+ feat = torch.cat(all_feat_list, dim=0)
305
+ # peak_memory_allocated = torch.cuda.max_memory_allocated()
306
+ # print(f"sae 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
307
+
308
+ del per_videos_or_images_features
309
+ del all_feat_list
310
+ torch.cuda.empty_cache()
311
+ all_videos_or_images_features.append(feat)
312
+ return all_videos_or_images_features
313
+
314
+ def interpolate(self,image_features):
315
+ b, num_tokens, dim = image_features.shape
316
+
317
+ #print(str(image_features.shape)+' i\n')
318
+
319
+ target_h = target_w = int(576**0.5)
320
+ h = w = int(num_tokens**0.5)
321
+
322
+ image_features = image_features.view(b, h, w, dim)
323
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
324
+
325
+ chunk_size = 24
326
+ chunks = torch.split(image_features, chunk_size, dim=0)
327
+ interpolated_chunks = []
328
+ for chunk in chunks:
329
+ interpolated_chunk = F.interpolate(
330
+ chunk.to(torch.float32),
331
+ size=(target_h, target_w),
332
+ mode="bilinear",
333
+ align_corners=False,
334
+ ).to(chunk.dtype)
335
+ interpolated_chunks.append(interpolated_chunk)
336
+ image_features = torch.cat(interpolated_chunks, dim=0)
337
+ del interpolated_chunks
338
+
339
+ del chunks
340
+
341
+ return image_features
342
+
343
+ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None,time_embedding=None):
344
+
345
+ vision_tower = self.get_vision_tower()
346
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
347
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
348
+
349
+ if type(images) is list or images.ndim == 5:
350
+ if type(images) is list:
351
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
352
+
353
+ video_idx_in_batch = []
354
+ for _ in range(len(modalities)):
355
+ if modalities[_] == "video":
356
+ video_idx_in_batch.append(_)
357
+
358
+ images_list = []
359
+ for image in images:
360
+ if image.ndim == 4:
361
+ images_list.append(image)
362
+ else:
363
+ images_list.append(image.unsqueeze(0))
364
+ #print(len(images_list),images_list[0].shape)
365
+
366
+ concat_images = torch.cat([image for image in images_list], dim=0)
367
+ split_sizes = [image.shape[0] for image in images_list]
368
+
369
+ image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) #16,144,3584
370
+
371
+ mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
372
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
373
+
374
+ visual_drop_score=[]
375
+ new_image_features=[]
376
+
377
+ if mm_patch_merge_type == "flat":
378
+
379
+ if image_features[0].ndim>2:
380
+ image_features = [x.flatten(0, 1) for x in image_features]
381
+ elif mm_patch_merge_type== "unires":
382
+ #print('unires')
383
+ for image_idx, image_feature in enumerate(image_features):
384
+ # rank0_print(f"Initial feature size : {image_feature.shape}")
385
+ if image_idx in video_idx_in_batch: # video operations
386
+ #print(image_feature.shape)
387
+ image_feature = image_feature.flatten(0, 1)
388
+
389
+ elif image_feature.shape[0] > 1:
390
+ # base image feature is never used in unires
391
+ base_image_feature = image_feature[0]
392
+ image_feature = image_feature[1:]
393
+
394
+ height = width = self.get_vision_tower().num_patches_per_side
395
+ assert height * width == base_image_feature.shape[0]
396
+
397
+ kernel_size = mm_patch_merge_type.split("avgpool")[-1].split("x")[-1]
398
+ kernel_size = 2
399
+ image_feature = image_feature.view(image_feature.shape[0], height, width, -1) # [4, 24, 24, 4096]
400
+ image_feature = image_feature.permute(0, 3, 1, 2).contiguous() # [4, 4096, 24, 24]
401
+ image_feature = nn.functional.avg_pool2d(image_feature,kernel_size) # [4, 4096, 12, 12]
402
+ image_feature = image_feature.flatten(2, 3) # [4, 4096, 144]
403
+ image_feature = image_feature.permute(0, 2, 1).contiguous() # [4, 144, 4096]
404
+
405
+ #print(image_feature.shape)
406
+ image_feature = image_feature.flatten(0, 1)
407
+
408
+ else:
409
+
410
+ image_feature = image_feature[0]
411
+
412
+ new_image_features.append(image_feature)
413
+
414
+ image_features = new_image_features
415
+
416
+ elif mm_patch_merge_type.startswith("spatial"):
417
+ new_image_features = []
418
+ for image_idx, image_feature in enumerate(image_features):
419
+ # FIXME: now assume the image is square, and split to 2x2 patches
420
+ # num_patches = h * w, where h = w = sqrt(num_patches)
421
+ # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
422
+ # we want to first unflatten it to (2, 2, h, w, hidden_size)
423
+ if image_idx in video_idx_in_batch: # video operations
424
+ if "unpad" in mm_patch_merge_type:
425
+ # image_feature = image_feature.permute(2, 0, 1).contiguous()
426
+ # image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
427
+ # image_feature = image_feature.permute(1, 2, 0).contiguous()
428
+ image_feature = image_feature.flatten(0, 1)
429
+ image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
430
+
431
+ elif image_feature.shape[0] > 1: # multi patches and multi images operations
432
+ base_image_feature = image_feature[0]
433
+ image_feature = image_feature[1:]
434
+ height = width = self.get_vision_tower().num_patches_per_side
435
+ assert height * width == base_image_feature.shape[0]
436
+
437
+ if "anyres_max" in image_aspect_ratio:
438
+ matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
439
+ if matched_anyres_max_num_patches:
440
+ max_num_patches = int(matched_anyres_max_num_patches.group(1))
441
+
442
+ if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
443
+ if hasattr(self.get_vision_tower(), "image_size"):
444
+ vision_tower_image_size = self.get_vision_tower().image_size
445
+ else:
446
+ raise ValueError("vision_tower_image_size is not found in the vision tower.")
447
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
448
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
449
+ else:
450
+ image_feature = image_feature.view(2, 2, height, width, -1)
451
+
452
+ if "maxpool2x2" in mm_patch_merge_type:
453
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
454
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
455
+ image_feature = nn.functional.max_pool2d(image_feature, 2)
456
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
457
+ elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
458
+ unit = image_feature.shape[2]
459
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
460
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
461
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
462
+ c, h, w = image_feature.shape
463
+ times = math.sqrt(h * w / (max_num_patches * unit**2))
464
+ if times > 1.1:
465
+ image_feature = image_feature[None]
466
+ image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
467
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
468
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
469
+ elif "unpad" in mm_patch_merge_type:
470
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
471
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
472
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
473
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
474
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
475
+ else:
476
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
477
+ image_feature = image_feature.flatten(0, 3)
478
+ if "nobase" in mm_patch_merge_type:
479
+ pass
480
+ else:
481
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
482
+ else: # single image operations
483
+ image_feature = image_feature[0]
484
+ if "unpad" in mm_patch_merge_type:
485
+ image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
486
+
487
+ new_image_features.append(image_feature)
488
+ image_features = new_image_features
489
+ else:
490
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
491
+ else:
492
+ error_message = """
493
+ Something is wrong with the input shape. Most likely, you did not wrap the image or video input in a list:
494
+ This is correct:
495
+ model.generate(input_ids, images=[video_tensor], modalities=["video"], **gen_kwargs)
496
+ model.generate(input_ids, images=[image_tensor], modalities=["image"], **gen_kwargs)
497
+ This is wrong:
498
+ model.generate(input_ids, images=video_tensor, modalities=["video"], **gen_kwargs)
499
+ model.generate(input_ids, images=image_tensor, modalities=["image"], **gen_kwargs)
500
+ """
501
+ raise ValueError(error_message)
502
+
503
+ #print(time_embedding[0].shape)
504
+ #video_token_indices=[]
505
+ for image_idx, image_feature in enumerate(image_features):
506
+ if time_embedding[image_idx] is not None:
507
+ mask = (time_embedding[image_idx] == 151654)
508
+ indices = torch.nonzero(mask).squeeze()
509
+
510
+ embed_token=self.get_model().embed_tokens(time_embedding[image_idx])
511
+ embed_token[indices]=image_features[image_idx]
512
+
513
+ #video_token_indices.append(indices)
514
+
515
+ image_features[image_idx]=embed_token
516
+
517
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
518
+ raise NotImplementedError
519
+
520
+ # Let's just add dummy tensors if they do not exist,
521
+ # it is a headache to deal with None all the time.
522
+ # But it is not ideal, and if you have a better idea,
523
+ # please open an issue / submit a PR, thanks.
524
+ _labels = labels
525
+ _position_ids = position_ids
526
+ _attention_mask = attention_mask
527
+ if attention_mask is None:
528
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
529
+ else:
530
+ attention_mask = attention_mask.bool()
531
+ if position_ids is None:
532
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
533
+ if labels is None:
534
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
535
+
536
+ # remove the padding using attention_mask -- FIXME
537
+ _input_ids = input_ids
538
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
539
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
540
+
541
+ new_input_embeds = []
542
+ new_labels = []
543
+ cur_image_idx = 0
544
+
545
+ for batch_idx, cur_input_ids in enumerate(input_ids):
546
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
547
+ #print(num_images)
548
+ if num_images>=2:
549
+ print(num_images,input_ids)
550
+ if num_images == 0:
551
+ cur_image_features = image_features[cur_image_idx]
552
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
553
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
554
+ new_input_embeds.append(cur_input_embeds)
555
+ new_labels.append(labels[batch_idx])
556
+ cur_image_idx += 1
557
+ continue
558
+
559
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
560
+ #print(image_token_indices) #[-1, 14, 236]
561
+ cur_input_ids_noim = []
562
+ cur_labels = labels[batch_idx]
563
+
564
+ # print(cur_input_ids)
565
+ # print(labels[batch_idx])
566
+
567
+ cur_labels_noim = []
568
+ for i in range(len(image_token_indices) - 1):
569
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
570
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
571
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
572
+
573
+ #print(torch.cat(cur_input_ids_noim).shape,torch.cat(cur_input_ids_noim))
574
+
575
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
576
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
577
+ cur_new_input_embeds = []
578
+ cur_new_labels = []
579
+
580
+ for i in range(num_images + 1):
581
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
582
+ cur_new_labels.append(cur_labels_noim[i])
583
+ if i < num_images:
584
+ ##############
585
+ cur_image_features = image_features[cur_image_idx]
586
+ cur_image_idx += 1
587
+ cur_new_input_embeds.append(cur_image_features)
588
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
589
+
590
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
591
+
592
+ # import pdb; pdb.set_trace()
593
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
594
+
595
+ cur_new_labels = torch.cat(cur_new_labels)
596
+
597
+ new_input_embeds.append(cur_new_input_embeds)
598
+ new_labels.append(cur_new_labels)
599
+
600
+ # Truncate sequences to max length as image embeddings can make the sequence longer
601
+ tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
602
+ # NOTE: qmh
603
+ # new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
604
+ # new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
605
+
606
+ # TODO: Hard code for control loss spike
607
+ # if tokenizer_model_max_length is not None:
608
+ # new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
609
+ # new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
610
+
611
+ # Combine them
612
+ max_len = max(x.shape[0] for x in new_input_embeds)
613
+ batch_size = len(new_input_embeds)
614
+
615
+ new_input_embeds_padded = []
616
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
617
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
618
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
619
+
620
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
621
+ cur_len = cur_new_embed.shape[0]
622
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
623
+ new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
624
+ if cur_len > 0:
625
+ new_labels_padded[i, -cur_len:] = cur_new_labels
626
+ attention_mask[i, -cur_len:] = True
627
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
628
+ else:
629
+ new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
630
+ if cur_len > 0:
631
+ new_labels_padded[i, :cur_len] = cur_new_labels
632
+ attention_mask[i, :cur_len] = True
633
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
634
+
635
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
636
+
637
+ if _labels is None:
638
+ new_labels = None
639
+ else:
640
+ new_labels = new_labels_padded
641
+
642
+ if _attention_mask is None:
643
+ attention_mask = None
644
+ else:
645
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
646
+
647
+ if _position_ids is None:
648
+ position_ids = None
649
+ if getattr(self.config, "use_pos_skipping", False) and self.training:
650
+ position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
651
+ split_position = random.randint(0, new_input_embeds.size(1))
652
+ left_add = random.randint(0, self.config.pos_skipping_range)
653
+ right_add = random.randint(left_add, self.config.pos_skipping_range)
654
+ position_ids[:, :split_position] += left_add
655
+ position_ids[:, split_position:] += right_add
656
+ # import pdb; pdb.set_trace()
657
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
658
+
659
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
660
+ if model_args.mm_use_im_patch_token:
661
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
662
+ self.resize_token_embeddings(len(tokenizer))
663
+
664
+ if model_args.mm_use_im_start_end:
665
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
666
+ self.resize_token_embeddings(len(tokenizer))
667
+
668
+ if num_new_tokens > 0:
669
+ input_embeddings = self.get_input_embeddings().weight.data
670
+ output_embeddings = self.get_output_embeddings().weight.data
671
+
672
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
673
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
674
+
675
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
676
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
677
+
678
+ if model_args.tune_mm_mlp_adapter:
679
+ for p in self.get_input_embeddings().parameters():
680
+ p.requires_grad = True
681
+ for p in self.get_output_embeddings().parameters():
682
+ p.requires_grad = False
683
+
684
+ if model_args.pretrain_mm_mlp_adapter:
685
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
686
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
687
+ assert num_new_tokens == 2
688
+ if input_embeddings.shape == embed_tokens_weight.shape:
689
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
690
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
691
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
692
+ else:
693
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
694
+
695
+ elif model_args.mm_use_im_patch_token:
696
+ if model_args.tune_mm_mlp_adapter:
697
+ for p in self.get_input_embeddings().parameters():
698
+ p.requires_grad = False
699
+ for p in self.get_output_embeddings().parameters():
700
+ p.requires_grad = False
701
 
702
 
703
  class LlavaQwenConfig(Qwen2Config):
 
1180
  )
1181
 
1182
  if inputs_embeds is None:
 
1183
  (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes, time_embedding)
1184
 
1185
  if self.config.enable_chunk_prefill:
 
1261
  **kwargs,
1262
  ) -> Union[GenerateOutput, torch.LongTensor]:
1263
 
 
 
1264
  position_ids = kwargs.pop("position_ids", None)
1265
  attention_mask = kwargs.pop("attention_mask", None)
1266
 
 
1323
  prompt = conv.get_prompt()
1324
  input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.model.device)
1325
 
1326
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
1327
+ keywords = [stop_str]
1328
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
1329
+ generation_config["stopping_criteria"] = [stopping_criteria]
1330
+
1331
  # prepare video input
1332
  frames, timestamps = load_video(video_path, max_num_frames, fps=sample_fps, max_fps=max_sample_fps)
1333
+ print(f'video has loaded, extract {len(frames)} frames.')
1334
 
1335
  time_stamps=[]
1336
  token_frames_sum=(len(timestamps)+3)//4
modeling_qwen2.py CHANGED
@@ -503,10 +503,12 @@ class Qwen2FlashAttention2(Qwen2Attention):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
+        key_position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        blocks_positions=None,
     ):
         bsz, q_len, _ = hidden_states.size()
 
multimodal_encoder/.ipynb_checkpoints/base_encoder-checkpoint.py DELETED
@@ -1,68 +0,0 @@
-from abc import ABC, abstractmethod
-
-import torch
-import torch.nn as nn
-
-
-class BaseVisionTower(nn.Module):
-    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
-        super().__init__()
-
-        self.is_loaded = False
-
-        self.vision_tower_name = vision_tower_name
-        self.delay_load = delay_load
-
-    @abstractmethod
-    def load_model(self, device_map=None):
-        raise NotImplementedError("Subclasses must implement load_model")
-
-    @abstractmethod
-    def _forward(self, images):
-        raise NotImplementedError("Subclasses must implement forward")
-
-    def forward(self, images):
-        if type(images) is list:
-            image_features = [self._forward(image.unsqueeze(0)) for image in images]
-        else:
-            image_features = self._forward(images)
-
-        return image_features
-
-    @property
-    def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-    @property
-    def dtype(self):
-        # Dynamically infer the dtype from the first parameter, if not explicitly specified
-        if hasattr(self.vision_tower, "dtype"):
-            return self.vision_tower.dtype
-        else:
-            params = list(self.vision_tower.parameters())
-            return (
-                params[0].dtype if len(params) > 0 else torch.float32
-            )  # Default to torch.float32 if no parameters
-
-    @property
-    def device(self):
-        # Dynamically infer the device from the first parameter, if not explicitly specified
-        if hasattr(self.vision_tower, "device"):
-            return self.vision_tower.device
-        else:
-            params = list(self.vision_tower.parameters())
-            return (
-                params[0].device if len(params) > 0 else torch.device("cpu")
-            )  # Default to CPU if no parameters
-    @property
-    def config(self):
-        if self.is_loaded:
-            return self.vision_tower.config
-        else:
-            return self.cfg_only
-    @property
-    def hidden_size(self):
-        try:
-            return self.config.hidden_size
-        except:
-            return self._hidden_size
multimodal_encoder/.ipynb_checkpoints/builder-checkpoint.py DELETED
@@ -1,29 +0,0 @@
-import os
-from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
-from .siglip_encoder import SigLipVisionTower
-# from .eva_clip.eva_clip_encoder import EvaClipVisionTower
-# from .dev_eva_clip.eva_vit import EvaViTWrapper
-
-
-def build_vision_tower(vision_tower_cfg, **kwargs):
-
-    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
-    is_absolute_path_exists = os.path.exists(vision_tower)
-    use_s2 = getattr(vision_tower_cfg, "s2", False)
-
-    #print(getattr(vision_tower_cfg, "vision_tower", None))
-    return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
-    if getattr(vision_tower_cfg, "vision_tower", None) and "siglip" in getattr(vision_tower_cfg, "vision_tower", None).lower():
-        #print('*************\n')
-        return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
-    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
-        if use_s2:
-            return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
-        else:
-            return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
-    # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower():
-    #     return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
-    # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]:
-    #     return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs)
-
-    raise ValueError(f"Unknown vision tower: {vision_tower}")
multimodal_encoder/.ipynb_checkpoints/clip_encoder-checkpoint.py DELETED
@@ -1,179 +0,0 @@
-import torch
-import torch.nn as nn
-from longva.longva.utils import rank0_print
-from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
-
-try:
-    from s2wrapper import forward as multiscale_forward
-except:
-    pass
-
-
-class CLIPVisionTower(nn.Module):
-    def __init__(self, vision_tower, args, delay_load=False):
-        super().__init__()
-
-        self.is_loaded = False
-
-        self.vision_tower_name = vision_tower
-        self.select_layer = args.mm_vision_select_layer
-        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
-
-        if not delay_load:
-            rank0_print(f"Loading vision tower: {vision_tower}")
-            self.load_model()
-        elif getattr(args, "unfreeze_mm_vision_tower", False):
-            # TODO: better detector is needed.
-            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
-            self.load_model()
-        elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
-            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
-            self.load_model()
-        else:
-            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
-
-    def load_model(self, device_map=None):
-        if self.is_loaded:
-            rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
-            return
-
-        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
-        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
-        self.vision_tower.requires_grad_(False)
-
-        self.is_loaded = True
-
-    def feature_select(self, image_forward_outs):
-        select_feature_type = self.select_feature
-
-        if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
-            select_every_k_layer = len(image_forward_outs.hidden_states) // 4
-            image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
-            select_feature_type = select_feature_type.replace("slicefour_", "")
-        elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
-            select_layers = [-2, -5, -8, -11, 6]
-            image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
-            select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
-        else:
-            image_features = image_forward_outs.hidden_states[self.select_layer]
-
-        if select_feature_type == "patch":
-            image_features = image_features[:, 1:]
-        elif select_feature_type == "cls_patch":
-            image_features = image_features
-        else:
-            raise ValueError(f"Unexpected select feature: {select_feature_type}")
-        return image_features
-
-    def forward(self, images):
-        if type(images) is list:
-            image_features = []
-            for image in images:
-                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
-                #print('image_feature before select ',image_forward_out.hidden_states[-1].shape)
-                image_feature = self.feature_select(image_forward_out).to(image.dtype)
-                #print('image_feature after select ',image_feature.shape)
-                image_features.append(image_feature)
-        else:
-            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
-            #print('image_feature before select ',image_forward_outs.hidden_states[-1].shape)
-            image_features = self.feature_select(image_forward_outs).to(images.dtype)
-            #print('image_feature after select ',image_features.shape)
-
-        return image_features
-
-    @property
-    def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-    @property
-    def dtype(self):
-        return self.vision_tower.dtype
-
-    @property
-    def device(self):
-        return self.vision_tower.device
-
-    @property
-    def config(self):
-        if self.is_loaded:
-            return self.vision_tower.config
-        else:
-            return self.cfg_only
-
-    @property
-    def hidden_size(self):
-        _hidden_size = self.config.hidden_size
-        if "slicefour" in self.select_feature:
-            _hidden_size *= 4
-        if "slice_m25811_f6" in self.select_feature:
-            _hidden_size *= 5
-        return _hidden_size
-
-    @property
-    def num_patches_per_side(self):
-        return self.config.image_size // self.config.patch_size
-
-    @property
-    def num_patches(self):
-        _num_patches = (self.config.image_size // self.config.patch_size) ** 2
-        if "cls_patch" in self.select_feature:
-            _num_patches += 1
-        return _num_patches
-
-    @property
-    def image_size(self):
-        return self.config.image_size
-
-
-class CLIPVisionTowerS2(CLIPVisionTower):
-    def __init__(self, vision_tower, args, delay_load=False):
-
-        self.s2_scales = getattr(args, "s2_scales", "336,672,1008")
-        self.s2_scales = list(map(int, self.s2_scales.split(",")))
-        self.s2_scales.sort()
-        self.s2_split_size = self.s2_scales[0]
-        self.s2_image_size = self.s2_scales[-1]
-
-        super().__init__(vision_tower, args, delay_load)
-
-        # change resize/crop size in preprocessing to the largest image size in s2_scale
-        if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False):
-            self.image_processor.size["shortest_edge"] = self.s2_image_size
-            self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
-
-    def load_model(self, device_map=None):
-        if self.is_loaded:
-            rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
-            return
-
-        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
-        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
-        self.vision_tower.requires_grad_(False)
-
-        self.image_processor.size["shortest_edge"] = self.s2_image_size
-        self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
-
-        self.is_loaded = True
-
-    @torch.no_grad()
-    def forward_feature(self, images):
-        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
-        image_features = self.feature_select(image_forward_outs).to(images.dtype)
-        return image_features
-
-    @torch.no_grad()
-    def forward(self, images):
-        if type(images) is list:
-            image_features = []
-            for image in images:
-                image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
-                image_features.append(image_feature)
-        else:
-            image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
-
-        return image_features
-
-    @property
-    def hidden_size(self):
-        return self.config.hidden_size * len(self.s2_scales)
multimodal_encoder/.ipynb_checkpoints/siglip_encoder-checkpoint.py DELETED
@@ -1,151 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch import nn
-from typing import Optional, Tuple, Union, Dict
-from PIL import Image
-from functools import partial, reduce
-from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
-
-from .base_encoder import BaseVisionTower
-import torch.distributed as dist
-# --data_path /share/shuyan/video_traindata/anno/\{cinepine_order\}.json \
-# --image_folder /share/shuyan/video_traindata/Bunny-v1_0-data/finetune/images \
-# --video_folder /share/shuyan/video_traindata \
-def rank0_print(*args):
-    if dist.is_initialized():
-        if dist.get_rank() == 0:
-            print(f"Rank {dist.get_rank()}: ", *args)
-    else:
-        print(*args)
-
-
-from transformers.image_processing_utils import BatchFeature, get_size_dict
-from transformers.image_transforms import (
-    convert_to_rgb,
-    normalize,
-    rescale,
-    resize,
-    to_channel_dimension_format,
-)
-from transformers.image_utils import (
-    ChannelDimension,
-    PILImageResampling,
-    to_numpy_array,
-)
-class SigLipImageProcessor:
-    def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
-        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
-        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
-
-        self.image_mean = image_mean
-        self.image_std = image_std
-        self.size = size
-        self.resample = resample
-        self.rescale_factor = rescale_factor
-        self.data_format = data_format
-        self.crop_size = crop_size
-
-    def preprocess(self, images, return_tensors):
-        if isinstance(images, Image.Image):
-            images = [images]
-        else:
-            # to adapt video data
-            images = [to_numpy_array(image) for image in images]
-        assert isinstance(images, list)
-
-        transforms = [
-            convert_to_rgb,
-            to_numpy_array,
-            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
-            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
-            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
-            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
-        ]
-
-        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
-
-        data = {"pixel_values": images}
-
-        return BatchFeature(data=data, tensor_type=return_tensors)
-
-class SigLipVisionTower(BaseVisionTower):
-    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
-        super(SigLipVisionTower, self).__init__(vision_tower_name, vision_tower_cfg, delay_load)
-
-        model_path = "google/siglip-so400m-patch14-384"
-        base_model_name, res, interp = model_path, 384, 576
-        self.vision_tower_name = base_model_name
-        self._image_size = res if res is not None else 512
-        self.unfreeze_mm_vision_tower = getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False)
-
-        if not delay_load:
-            rank0_print(f"Loading vision tower: {vision_tower_name}")
-            self.load_model()
-        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
-            # TODO: better detector is needed.
-            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
-            self.load_model()
-        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
-            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
-            self.load_model()
-        else:
-            self.cfg_only = self.config
-
-    def load_model(self, device_map=None):
-        self.vision_model = "siglip"
-        # clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
-        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
-
-        # self.vision_tower = clip_model.visual.trunk
-        self.vision_tower.output_tokens = True
-
-        self._hidden_size = self.vision_tower.config.hidden_size
-
-        self.image_processor = SigLipImageProcessor()
-
-        del self.vision_tower.vision_model.encoder.layers[-1:]
-        self.vision_tower.vision_model.head = nn.Identity()
-
-        self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
-        self.is_loaded = True
-
-    def _forward(self, images):
-        with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
-            image_features = self.vision_tower.forward(
-                images.to(device=self.device, dtype=self.dtype),
-                output_hidden_states=True,
-            ).hidden_states[-1]
-            return image_features
-    @property
-    def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-    @property
-    def dtype(self):
-        for p in self.vision_tower.parameters():
-            return p.dtype
-
-    @property
-    def device(self):
-        for p in self.vision_tower.parameters():
-            return p.device
-
-    @property
-    def hidden_size(self):
-        return self.config.hidden_size
-
-    @property
-    def num_patches(self):
-        return (336 // 14) ** 2
-
-    @property
-    def num_patches_per_side(self):
-        #return self.config.image_size // self.config.patch_size
-        return 336 // 14
-        #return 27
-        # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
-
-    @property
-    def image_size(self):
-        return 384
-        #return self.config.image_size
multimodal_encoder/__pycache__/base_encoder.cpython-310.pyc DELETED
Binary file (2.62 kB)
 
multimodal_encoder/__pycache__/builder.cpython-310.pyc DELETED
Binary file (697 Bytes)
 
multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc DELETED
Binary file (6.53 kB)
 
multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc DELETED
Binary file (5.81 kB)
 
multimodal_encoder/base_encoder.py DELETED
@@ -1,68 +0,0 @@
-from abc import ABC, abstractmethod
-
-import torch
-import torch.nn as nn
-
-
-class BaseVisionTower(nn.Module):
-    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
-        super().__init__()
-
-        self.is_loaded = False
-
-        self.vision_tower_name = vision_tower_name
-        self.delay_load = delay_load
-
-    @abstractmethod
-    def load_model(self, device_map=None):
-        raise NotImplementedError("Subclasses must implement load_model")
-
-    @abstractmethod
-    def _forward(self, images):
-        raise NotImplementedError("Subclasses must implement forward")
-
-    def forward(self, images):
-        if type(images) is list:
-            image_features = [self._forward(image.unsqueeze(0)) for image in images]
-        else:
-            image_features = self._forward(images)
-
-        return image_features
-
-    @property
-    def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-    @property
-    def dtype(self):
-        # Dynamically infer the dtype from the first parameter, if not explicitly specified
-        if hasattr(self.vision_tower, "dtype"):
-            return self.vision_tower.dtype
-        else:
-            params = list(self.vision_tower.parameters())
-            return (
-                params[0].dtype if len(params) > 0 else torch.float32
-            )  # Default to torch.float32 if no parameters
-
-    @property
-    def device(self):
-        # Dynamically infer the device from the first parameter, if not explicitly specified
-        if hasattr(self.vision_tower, "device"):
-            return self.vision_tower.device
-        else:
-            params = list(self.vision_tower.parameters())
-            return (
-                params[0].device if len(params) > 0 else torch.device("cpu")
-            )  # Default to CPU if no parameters
-    @property
-    def config(self):
-        if self.is_loaded:
-            return self.vision_tower.config
-        else:
-            return self.cfg_only
-    @property
-    def hidden_size(self):
-        try:
-            return self.config.hidden_size
-        except:
-            return self._hidden_size
multimodal_encoder/builder.py DELETED
@@ -1,20 +0,0 @@
-import os
-from .siglip_encoder import SigLipVisionTower
-# from .eva_clip.eva_clip_encoder import EvaClipVisionTower
-# from .dev_eva_clip.eva_vit import EvaViTWrapper
-
-
-def build_vision_tower(vision_tower_cfg, **kwargs):
-
-    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
-    is_absolute_path_exists = os.path.exists(vision_tower)
-    use_s2 = getattr(vision_tower_cfg, "s2", False)
-
-    #print(getattr(vision_tower_cfg, "vision_tower", None))
-    return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
-    if getattr(vision_tower_cfg, "vision_tower", None) and "siglip" in getattr(vision_tower_cfg, "vision_tower", None).lower():
-        #print('*************\n')
-        return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
-
-
-    raise ValueError(f"Unknown vision tower: {vision_tower}")
multimodal_projector/__pycache__/builder.cpython-310.pyc DELETED
Binary file (2.4 kB)
 
multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc DELETED
Binary file (1.47 kB)
 
multimodal_projector/pooler_projector.py DELETED
@@ -1,33 +0,0 @@
-import torch
-import torch.nn as nn
-
-import math
-
-from transformers.models.clip.modeling_clip import CLIPVisionModel
-
-
-class PoolerProjector(nn.Module):
-    def __init__(self, config, vision_cfg):
-        super().__init__()
-        self._config = config
-        self.hw = vision_cfg.image_size // vision_cfg.patch_size
-
-        self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
-
-        self.proj = nn.Sequential(
-            nn.GELU(),
-            nn.Linear(config.hidden_size, config.hidden_size),
-        )
-
-    def forward(self, x, *args, **kwargs):
-        height = width = self.hw
-        assert height * width == x.shape[1]
-        x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
-        x = self.conv_pool(x)
-        x = x.flatten(2).transpose(1, 2)
-        x = self.proj(x)
-        return x
-
-    @property
-    def config(self):
-        return {"mm_projector_type": "pooler"}
multimodal_resampler/__pycache__/builder.cpython-310.pyc DELETED
Binary file (1.45 kB)
 
multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc DELETED
Binary file (2.47 kB)
 
multimodal_resampler/__pycache__/perceiver.cpython-310.pyc DELETED
Binary file (4.86 kB)
 
multimodal_resampler/__pycache__/qformer.cpython-310.pyc DELETED
Binary file (32.7 kB)
 
multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc DELETED
Binary file (1.9 kB)
 
multimodal_resampler/builder.py DELETED
@@ -1,34 +0,0 @@
-import torch
-
-from .masked_drop import MaskedDrop
-from .spatial_pool import SpatialPool
-from .perceiver import PerceiverResampler
-from .qformer import Qformer
-
-
-class IdentityMap(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, *args, **kwargs):
-        return x
-
-    @property
-    def config(self):
-        return {"mm_resampler_type": None}
-
-
-def build_vision_resampler(model_args, delay_load=False, **kwargs):
-    resampler_type = getattr(model_args, "mm_resampler_type", None)
-    if resampler_type == "masked_drop":
-        return MaskedDrop(model_args)
-    elif resampler_type == "spatial_pool":
-        return SpatialPool(model_args, **kwargs)
-    elif resampler_type == "perceiver":
-        return PerceiverResampler(model_args, **kwargs)
-    elif resampler_type == "qformer":
-        return Qformer(model_args, **kwargs)
-    elif resampler_type is None:
-        return IdentityMap()
-
-    raise ValueError(f"Unknown resampler type: {resampler_type}")
multimodal_resampler/masked_drop.py DELETED
@@ -1,80 +0,0 @@
-import torch
-import torch.nn as nn
-
-import random
-
-
-class MaskedDrop(nn.Module):
-    def __init__(self, model_args):
-        super().__init__()
-
-        self.mode = model_args.mm_mask_drop_mode
-        self.skip_percentage = model_args.mm_mask_drop_skip_percentage
-        self.ratio = model_args.mm_mask_drop_ratio
-        self.ratio_upper = model_args.mm_mask_drop_ratio_upper
-        self.ratio_lower = model_args.mm_mask_drop_ratio_lower
-
-    def forward(self, image_features, *args, **kwargs):
-
-        if not self.training:
-            return image_features
-
-        if self.skip_percentage > random.random():
-            return image_features
-
-        masked_features = []
-
-        for image_feature in image_features:
-            num_tokens = image_feature.shape[0]
-            if self.mode == "fixed":
-                num_keep = int(num_tokens * self.ratio)
-                masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
-            elif self.mode == "range":
-                num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
-                masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
-            elif self.mode == "cls_only":
-                masked_features.append(image_feature[0:1])
-            else:
-                raise ValueError(f"Unexpected masked drop mode: {self.mode}")
-
-        if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
-            masked_features = torch.stack(masked_features, dim=0)
-
-        return masked_features
-
-    @property
-    def config(self):
-        return {
-            "mm_resampler_type": "masked_drop",
-            "mm_mask_drop_mode": self.mode,
-            "mm_mask_drop_skip_percentage": self.skip_percentage,
-            "mm_mask_drop_ratio": self.ratio,
-            "mm_mask_drop_ratio_upper": self.ratio_upper,
-            "mm_mask_drop_ratio_lower": self.ratio_lower,
-        }
-
-    def random_masking(self, x, len_keep):
-        """
-        Perform per-sample random masking by per-sample shuffling.
-        Per-sample shuffling is done by argsort random noise.
-        x: [N, L, D], sequence
-        """
-        N, L, D = x.shape  # batch, length, dim
-
-        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
-
-        # sort noise for each sample
-        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
-        ids_restore = torch.argsort(ids_shuffle, dim=1)
-
-        # keep the first subset
-        ids_keep = ids_shuffle[:, :len_keep]
-        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
-
-        # generate the binary mask: 0 is keep, 1 is remove
-        mask = torch.ones([N, L], device=x.device)
-        mask[:, :len_keep] = 0
-        # unshuffle to get the binary mask
-        mask = torch.gather(mask, dim=1, index=ids_restore)
-
-        return x_masked, mask, ids_restore
multimodal_resampler/perceiver.py DELETED
@@ -1,155 +0,0 @@
-"""
-Taken from https://github.com/lucidrains/flamingo-pytorch
-"""
-
-import torch
-from einops import rearrange, repeat
-
-try:
-    from einops_exts import rearrange_many
-except:
-    pass
-
-from torch import einsum, nn
-
-
-def exists(val):
-    return val is not None
-
-
-def FeedForward(dim, mult=4):
-    inner_dim = int(dim * mult)
-    return nn.Sequential(
-        nn.LayerNorm(dim),
-        nn.Linear(dim, inner_dim, bias=False),
-        nn.GELU(),
-        nn.Linear(inner_dim, dim, bias=False),
-    )
-
-
-class PerceiverAttention(nn.Module):
-    def __init__(self, *, dim, dim_head=64, heads=8):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        inner_dim = dim_head * heads
-
-        self.norm_media = nn.LayerNorm(dim)
-        self.norm_latents = nn.LayerNorm(dim)
-
-        self.to_q = nn.Linear(dim, inner_dim, bias=False)
-        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
-        self.to_out = nn.Linear(inner_dim, dim, bias=False)
-
-    def forward(self, x, latents):
-        """
-        Args:
-            x (torch.Tensor): image features
-                shape (b, T, n1, D)
-            latent (torch.Tensor): latent features
-                shape (b, T, n2, D)
-        """
-        x = self.norm_media(x)
-        latents = self.norm_latents(latents)
-
-        h = self.heads
-
-        q = self.to_q(latents)
-        kv_input = torch.cat((x, latents), dim=-2)
-        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
-        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
-        q = q * self.scale
-
-        # attention
-        sim = einsum("... i d, ... j d -> ... i j", q, k)
-        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("... i j, ... j d -> ... i d", attn, v)
-        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
-        return self.to_out(out)
-
-
-class PerceiverResamplerModule(nn.Module):
-    def __init__(
-        self,
-        *,
-        dim,
-        depth=6,
-        dim_head=64,
-        heads=8,
-        num_latents=64,
-        max_num_media=None,
-        max_num_frames=None,
-        ff_mult=4,
-    ):
-        super().__init__()
-        self.latents = nn.Parameter(torch.randn(num_latents, dim))
-        self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
-        self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
-
-        self.layers = nn.ModuleList([])
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
-                        FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(),
-                    ]
-                )
-            )
-
-        self.norm = nn.LayerNorm(dim)
-
-    def forward(self, x):
-        """
-        Args:
-            x (torch.Tensor): image features
-                shape (b, T, F, v, D)
-        Returns:
-            shape (b, T, n, D) where n is self.num_latents
-        """
-        b, T, F, v = x.shape[:4]
-
-        # frame and media time embeddings
-        if exists(self.frame_embs):
-            frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
-            x = x + frame_embs
-        x = rearrange(x, "b T F v d -> b T (F v) d")  # flatten the frame and spatial dimensions
-        if exists(self.media_time_embs):
-            x = x + self.media_time_embs[:T]
-
-        # blocks
-        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
-        for attn, ff in self.layers:
-            latents = attn(x, latents) + latents
-            latents = ff(latents) + latents
-        return self.norm(latents)
-
-
-class PerceiverResampler(nn.Module):
-    def __init__(self, model_args, vision_tower):
-        super().__init__()
-
-        self.depth = model_args.mm_perceiver_depth
-        self.num_latents = model_args.mm_perceiver_latents
-        self.ff_mult = model_args.mm_perceiver_ff_mult
-        self.pretrained = model_args.mm_perceiver_pretrained
-
-        self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)
-
-        if self.pretrained is not None:
-            self.load_state_dict(torch.load(self.pretrained))
-
-    def forward(self, image_features, *args, **kwargs):
-        return self.perceiver(image_features[:, None, None]).squeeze(1)
-
-    @property
-    def config(self):
-        return {
-            "mm_resampler_type": "perceiver",
-            "mm_perceiver_depth": self.depth,
-            "mm_perceiver_latents": self.num_latents,
-            "mm_perceiver_ff_mult": self.ff_mult,
-            "mm_perceiver_pretrained": self.pretrained,
-        }
multimodal_resampler/qformer.py DELETED
@@ -1,1160 +0,0 @@
1
- """
2
- * Copyright (c) 2023, salesforce.com, inc.
3
- * All rights reserved.
4
- * SPDX-License-Identifier: BSD-3-Clause
5
- * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- * By Junnan Li
7
- * Based on huggingface code base
8
- * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
- """
10
-
11
- import math
12
- import os
13
- import warnings
14
- from dataclasses import dataclass
15
- from typing import Optional, Tuple, Dict, Any
16
-
17
- import torch
18
- from torch import Tensor, device, dtype, nn
19
- import torch.utils.checkpoint
20
- from torch import nn
21
- from torch.nn import CrossEntropyLoss
22
- import torch.nn.functional as F
23
-
24
- from transformers.activations import ACT2FN
25
- from transformers.file_utils import (
26
- ModelOutput,
27
- )
28
- from transformers.modeling_outputs import (
29
- BaseModelOutputWithPastAndCrossAttentions,
30
- BaseModelOutputWithPoolingAndCrossAttentions,
31
- CausalLMOutputWithCrossAttentions,
32
- MaskedLMOutput,
33
- MultipleChoiceModelOutput,
34
- NextSentencePredictorOutput,
35
- QuestionAnsweringModelOutput,
36
- SequenceClassifierOutput,
37
- TokenClassifierOutput,
38
- )
39
- from transformers.modeling_utils import (
40
- PreTrainedModel,
41
- apply_chunking_to_forward,
42
- find_pruneable_heads_and_indices,
43
- prune_linear_layer,
44
- )
45
- from transformers.utils import logging
46
- from transformers.models.bert.configuration_bert import BertConfig
47
-
48
- logger = logging.get_logger(__name__)
49
-
50
-
51
- def disabled_train(self, mode=True):
52
- """Overwrite model.train with this function to make sure train/eval mode
53
- does not change anymore."""
54
- return self
55
-
56
-
57
- class BertEmbeddings(nn.Module):
58
- """Construct the embeddings from word and position embeddings."""
59
-
60
- def __init__(self, config):
61
- super().__init__()
62
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
63
- self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
64
-
65
- # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
66
- # any TensorFlow checkpoint file
67
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
68
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
69
-
70
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
71
- self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
72
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
73
-
74
- self.config = config
75
-
76
- def forward(
77
- self,
78
- input_ids=None,
79
- position_ids=None,
80
- query_embeds=None,
81
- past_key_values_length=0,
82
- ):
83
- if input_ids is not None:
84
- seq_length = input_ids.size()[1]
85
- else:
86
- seq_length = 0
87
-
88
- if position_ids is None:
89
- position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
90
-
91
- if input_ids is not None:
92
- embeddings = self.word_embeddings(input_ids)
93
- if self.position_embedding_type == "absolute":
94
- position_embeddings = self.position_embeddings(position_ids)
95
- embeddings = embeddings + position_embeddings
96
-
97
- if query_embeds is not None:
98
- embeddings = torch.cat((query_embeds, embeddings), dim=1)
99
- else:
100
- embeddings = query_embeds
101
-
102
- embeddings = self.LayerNorm(embeddings)
103
- embeddings = self.dropout(embeddings)
104
- return embeddings
105
-
106
-
107
- class BertSelfAttention(nn.Module):
108
- def __init__(self, config, is_cross_attention):
109
- super().__init__()
110
- self.config = config
111
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
112
- raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads))
113
-
114
- self.num_attention_heads = config.num_attention_heads
115
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
116
- self.all_head_size = self.num_attention_heads * self.attention_head_size
117
-
118
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
119
- if is_cross_attention:
120
- self.key = nn.Linear(config.encoder_width, self.all_head_size)
121
- self.value = nn.Linear(config.encoder_width, self.all_head_size)
122
- else:
123
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
124
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
125
-
126
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
127
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
128
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
129
- self.max_position_embeddings = config.max_position_embeddings
130
- self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
131
- self.save_attention = False
132
-
133
- def save_attn_gradients(self, attn_gradients):
134
- self.attn_gradients = attn_gradients
135
-
136
- def get_attn_gradients(self):
137
- return self.attn_gradients
138
-
139
- def save_attention_map(self, attention_map):
140
- self.attention_map = attention_map
141
-
142
- def get_attention_map(self):
143
- return self.attention_map
144
-
145
- def transpose_for_scores(self, x):
146
- new_x_shape = x.size()[:-1] + (
147
- self.num_attention_heads,
148
- self.attention_head_size,
149
- )
150
- x = x.view(*new_x_shape)
151
- return x.permute(0, 2, 1, 3)
152
-
153
- def forward(
154
- self,
155
- hidden_states,
156
- attention_mask=None,
157
- head_mask=None,
158
- encoder_hidden_states=None,
159
- encoder_attention_mask=None,
160
- past_key_value=None,
161
- output_attentions=False,
162
- ):
163
-
164
- # If this is instantiated as a cross-attention module, the keys
165
- # and values come from an encoder; the attention mask needs to be
166
- # such that the encoder's padding tokens are not attended to.
167
- is_cross_attention = encoder_hidden_states is not None
168
-
169
- if is_cross_attention:
170
- key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
171
- value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
172
- attention_mask = encoder_attention_mask
173
- elif past_key_value is not None:
174
- key_layer = self.transpose_for_scores(self.key(hidden_states))
175
- value_layer = self.transpose_for_scores(self.value(hidden_states))
176
- key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
177
- value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
178
- else:
179
- key_layer = self.transpose_for_scores(self.key(hidden_states))
180
- value_layer = self.transpose_for_scores(self.value(hidden_states))
181
-
182
- mixed_query_layer = self.query(hidden_states)
183
-
184
- query_layer = self.transpose_for_scores(mixed_query_layer)
185
-
186
- past_key_value = (key_layer, value_layer)
187
-
188
- # Take the dot product between "query" and "key" to get the raw attention scores.
189
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
190
-
191
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
192
- seq_length = hidden_states.size()[1]
193
- position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
194
- position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
195
- distance = position_ids_l - position_ids_r
196
- positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
197
- positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
198
-
199
- if self.position_embedding_type == "relative_key":
200
- relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
201
- attention_scores = attention_scores + relative_position_scores
202
- elif self.position_embedding_type == "relative_key_query":
203
- relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
204
- relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
205
- attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
206
-
207
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
208
- if attention_mask is not None:
209
- # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
210
- attention_scores = attention_scores + attention_mask
211
-
212
- # Normalize the attention scores to probabilities.
213
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
214
-
215
- if is_cross_attention and self.save_attention:
216
- self.save_attention_map(attention_probs)
217
- attention_probs.register_hook(self.save_attn_gradients)
218
-
219
- # This is actually dropping out entire tokens to attend to, which might
220
- # seem a bit unusual, but is taken from the original Transformer paper.
221
- attention_probs_dropped = self.dropout(attention_probs)
222
-
223
- # Mask heads if we want to
224
- if head_mask is not None:
225
- attention_probs_dropped = attention_probs_dropped * head_mask
226
-
227
- context_layer = torch.matmul(attention_probs_dropped, value_layer)
228
-
229
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
230
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
231
- context_layer = context_layer.view(*new_context_layer_shape)
232
-
233
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
234
-
235
- outputs = outputs + (past_key_value,)
236
- return outputs
237
-
238
-
239
- class BertSelfOutput(nn.Module):
240
- def __init__(self, config):
241
- super().__init__()
242
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
243
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
244
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
245
-
246
- def forward(self, hidden_states, input_tensor):
247
- hidden_states = self.dense(hidden_states)
248
- hidden_states = self.dropout(hidden_states)
249
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
250
- return hidden_states
251
-
252
-
253
- class BertAttention(nn.Module):
254
- def __init__(self, config, is_cross_attention=False):
255
- super().__init__()
256
- self.self = BertSelfAttention(config, is_cross_attention)
257
- self.output = BertSelfOutput(config)
258
- self.pruned_heads = set()
259
-
260
- def prune_heads(self, heads):
261
- if len(heads) == 0:
262
- return
263
- heads, index = find_pruneable_heads_and_indices(
264
- heads,
265
- self.self.num_attention_heads,
266
- self.self.attention_head_size,
267
- self.pruned_heads,
268
- )
269
-
270
- # Prune linear layers
271
- self.self.query = prune_linear_layer(self.self.query, index)
272
- self.self.key = prune_linear_layer(self.self.key, index)
273
- self.self.value = prune_linear_layer(self.self.value, index)
274
- self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
275
-
276
- # Update hyper params and store pruned heads
277
- self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
278
- self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
279
- self.pruned_heads = self.pruned_heads.union(heads)
280
-
281
- def forward(
282
- self,
283
- hidden_states,
284
- attention_mask=None,
285
- head_mask=None,
286
- encoder_hidden_states=None,
287
- encoder_attention_mask=None,
288
- past_key_value=None,
289
- output_attentions=False,
290
- ):
291
- self_outputs = self.self(
292
- hidden_states,
293
- attention_mask,
294
- head_mask,
295
- encoder_hidden_states,
296
- encoder_attention_mask,
297
- past_key_value,
298
- output_attentions,
299
- )
300
- attention_output = self.output(self_outputs[0], hidden_states)
301
-
302
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
303
- return outputs
304
-
305
-
306
- class BertIntermediate(nn.Module):
307
- def __init__(self, config):
308
- super().__init__()
309
- self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
310
- if isinstance(config.hidden_act, str):
311
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
312
- else:
313
- self.intermediate_act_fn = config.hidden_act
314
-
315
- def forward(self, hidden_states):
316
- hidden_states = self.dense(hidden_states)
317
- hidden_states = self.intermediate_act_fn(hidden_states)
318
- return hidden_states
319
-
320
-
321
- class BertOutput(nn.Module):
322
- def __init__(self, config):
323
- super().__init__()
324
- self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
325
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
326
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
327
-
328
- def forward(self, hidden_states, input_tensor):
329
- hidden_states = self.dense(hidden_states)
330
- hidden_states = self.dropout(hidden_states)
331
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
332
- return hidden_states
333
-
334
-
335
- class BertLayer(nn.Module):
336
- def __init__(self, config, layer_num):
337
- super().__init__()
338
- self.config = config
339
- self.chunk_size_feed_forward = config.chunk_size_feed_forward
340
- self.seq_len_dim = 1
341
- self.attention = BertAttention(config)
342
- self.layer_num = layer_num
343
- if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0:
344
- self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
345
- self.has_cross_attention = True
346
- else:
347
- self.has_cross_attention = False
348
- self.intermediate = BertIntermediate(config)
349
- self.output = BertOutput(config)
350
-
351
- self.intermediate_query = BertIntermediate(config)
352
- self.output_query = BertOutput(config)
353
-
354
- def forward(
355
- self,
356
- hidden_states,
357
- attention_mask=None,
358
- head_mask=None,
359
- encoder_hidden_states=None,
360
- encoder_attention_mask=None,
361
- past_key_value=None,
362
- output_attentions=False,
363
- query_length=0,
364
- ):
365
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
366
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
367
- self_attention_outputs = self.attention(
368
- hidden_states,
369
- attention_mask,
370
- head_mask,
371
- output_attentions=output_attentions,
372
- past_key_value=self_attn_past_key_value,
373
- )
374
- attention_output = self_attention_outputs[0]
375
- outputs = self_attention_outputs[1:-1]
376
-
377
- present_key_value = self_attention_outputs[-1]
378
-
379
- if query_length > 0:
380
- query_attention_output = attention_output[:, :query_length, :]
381
-
382
- if self.has_cross_attention:
383
- assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
384
- cross_attention_outputs = self.crossattention(
385
- query_attention_output,
386
- attention_mask,
387
- head_mask,
388
- encoder_hidden_states,
389
- encoder_attention_mask,
390
- output_attentions=output_attentions,
391
- )
392
- query_attention_output = cross_attention_outputs[0]
393
- outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
394
-
395
- layer_output = apply_chunking_to_forward(
396
- self.feed_forward_chunk_query,
397
- self.chunk_size_feed_forward,
398
- self.seq_len_dim,
399
- query_attention_output,
400
- )
401
- if attention_output.shape[1] > query_length:
402
- layer_output_text = apply_chunking_to_forward(
403
- self.feed_forward_chunk,
404
- self.chunk_size_feed_forward,
405
- self.seq_len_dim,
406
- attention_output[:, query_length:, :],
407
- )
408
- layer_output = torch.cat([layer_output, layer_output_text], dim=1)
409
- else:
410
- layer_output = apply_chunking_to_forward(
411
- self.feed_forward_chunk,
412
- self.chunk_size_feed_forward,
413
- self.seq_len_dim,
414
- attention_output,
415
- )
416
- outputs = (layer_output,) + outputs
417
-
418
- outputs = outputs + (present_key_value,)
419
-
420
- return outputs
421
-
422
- def feed_forward_chunk(self, attention_output):
423
- intermediate_output = self.intermediate(attention_output)
424
- layer_output = self.output(intermediate_output, attention_output)
425
- return layer_output
426
-
427
- def feed_forward_chunk_query(self, attention_output):
428
- intermediate_output = self.intermediate_query(attention_output)
429
- layer_output = self.output_query(intermediate_output, attention_output)
430
- return layer_output
431
-
432
-
433
- class BertEncoder(nn.Module):
434
- def __init__(self, config):
435
- super().__init__()
436
- self.config = config
437
- self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
438
-
439
- def forward(
440
- self,
441
- hidden_states,
442
- attention_mask=None,
443
- head_mask=None,
444
- encoder_hidden_states=None,
445
- encoder_attention_mask=None,
446
- past_key_values=None,
447
- use_cache=None,
448
- output_attentions=False,
449
- output_hidden_states=False,
450
- return_dict=True,
451
- query_length=0,
452
- ):
453
- all_hidden_states = () if output_hidden_states else None
454
- all_self_attentions = () if output_attentions else None
455
- all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
456
-
457
- next_decoder_cache = () if use_cache else None
458
-
459
- for i in range(self.config.num_hidden_layers):
460
- layer_module = self.layer[i]
461
- if output_hidden_states:
462
- all_hidden_states = all_hidden_states + (hidden_states,)
463
-
464
- layer_head_mask = head_mask[i] if head_mask is not None else None
465
- past_key_value = past_key_values[i] if past_key_values is not None else None
466
-
467
- if getattr(self.config, "gradient_checkpointing", False) and self.training:
468
-
469
- if use_cache:
470
- logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
471
- use_cache = False
472
-
473
- def create_custom_forward(module):
474
- def custom_forward(*inputs):
475
- return module(*inputs, past_key_value, output_attentions, query_length)
476
-
477
- return custom_forward
478
-
479
- layer_outputs = torch.utils.checkpoint.checkpoint(
480
- create_custom_forward(layer_module),
481
- hidden_states,
482
- attention_mask,
483
- layer_head_mask,
484
- encoder_hidden_states,
485
- encoder_attention_mask,
486
- )
487
- else:
488
- layer_outputs = layer_module(
489
- hidden_states,
490
- attention_mask,
491
- layer_head_mask,
492
- encoder_hidden_states,
493
- encoder_attention_mask,
494
- past_key_value,
495
- output_attentions,
496
- query_length,
497
- )
498
-
499
- hidden_states = layer_outputs[0]
500
- if use_cache:
501
- next_decoder_cache += (layer_outputs[-1],)
502
- if output_attentions:
503
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
504
- all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
505
-
506
- if output_hidden_states:
507
- all_hidden_states = all_hidden_states + (hidden_states,)
508
-
509
- if not return_dict:
510
- return tuple(
511
- v
512
- for v in [
513
- hidden_states,
514
- next_decoder_cache,
515
- all_hidden_states,
516
- all_self_attentions,
517
- all_cross_attentions,
518
- ]
519
- if v is not None
520
- )
521
- return BaseModelOutputWithPastAndCrossAttentions(
522
- last_hidden_state=hidden_states,
523
- past_key_values=next_decoder_cache,
524
- hidden_states=all_hidden_states,
525
- attentions=all_self_attentions,
526
- cross_attentions=all_cross_attentions,
527
- )
528
-
529
-
530
- class BertPooler(nn.Module):
531
- def __init__(self, config):
532
- super().__init__()
533
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
534
- self.activation = nn.Tanh()
535
-
536
- def forward(self, hidden_states):
537
- # We "pool" the model by simply taking the hidden state corresponding
538
- # to the first token.
539
- first_token_tensor = hidden_states[:, 0]
540
- pooled_output = self.dense(first_token_tensor)
541
- pooled_output = self.activation(pooled_output)
542
- return pooled_output
543
-
544
-
545
- class BertPredictionHeadTransform(nn.Module):
546
- def __init__(self, config):
547
- super().__init__()
548
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
549
- if isinstance(config.hidden_act, str):
550
- self.transform_act_fn = ACT2FN[config.hidden_act]
551
- else:
552
- self.transform_act_fn = config.hidden_act
553
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
554
-
555
- def forward(self, hidden_states):
556
- hidden_states = self.dense(hidden_states)
557
- hidden_states = self.transform_act_fn(hidden_states)
558
- hidden_states = self.LayerNorm(hidden_states)
559
- return hidden_states
560
-
561
-
562
- class BertLMPredictionHead(nn.Module):
563
- def __init__(self, config):
564
- super().__init__()
565
- self.transform = BertPredictionHeadTransform(config)
566
-
567
- # The output weights are the same as the input embeddings, but there is
568
- # an output-only bias for each token.
569
- self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
570
-
571
- self.bias = nn.Parameter(torch.zeros(config.vocab_size))
572
-
573
- # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
574
- self.decoder.bias = self.bias
575
-
576
- def forward(self, hidden_states):
577
- hidden_states = self.transform(hidden_states)
578
- hidden_states = self.decoder(hidden_states)
579
- return hidden_states
580
-
581
-
582
- class BertOnlyMLMHead(nn.Module):
583
- def __init__(self, config):
584
- super().__init__()
585
- self.predictions = BertLMPredictionHead(config)
586
-
587
- def forward(self, sequence_output):
588
- prediction_scores = self.predictions(sequence_output)
589
- return prediction_scores
590
-
591
-
592
- class BertPreTrainedModel(PreTrainedModel):
593
- """
594
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
595
- models.
596
- """
597
-
598
- config_class = BertConfig
599
- base_model_prefix = "bert"
600
- _keys_to_ignore_on_load_missing = [r"position_ids"]
601
-
602
- def _init_weights(self, module):
603
- """Initialize the weights"""
604
- if isinstance(module, (nn.Linear, nn.Embedding)):
605
- # Slightly different from the TF version which uses truncated_normal for initialization
606
- # cf https://github.com/pytorch/pytorch/pull/5617
607
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
608
- elif isinstance(module, nn.LayerNorm):
609
- module.bias.data.zero_()
610
- module.weight.data.fill_(1.0)
611
- if isinstance(module, nn.Linear) and module.bias is not None:
612
- module.bias.data.zero_()
613
-
614
-
615
- class BertModel(BertPreTrainedModel):
616
- """
617
- The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
618
- cross-attention is added between the self-attention layers, following the architecture described in `Attention is
619
- all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
620
- Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
621
- argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
622
- input to the forward pass.
623
- """
624
-
625
- def __init__(self, config, add_pooling_layer=False):
626
- super().__init__(config)
627
- self.config = config
628
-
629
- self.embeddings = BertEmbeddings(config)
630
-
631
- self.encoder = BertEncoder(config)
632
-
633
- self.pooler = BertPooler(config) if add_pooling_layer else None
634
-
635
- self.init_weights()
636
-
637
- def get_input_embeddings(self):
638
- return self.embeddings.word_embeddings
639
-
640
- def set_input_embeddings(self, value):
641
- self.embeddings.word_embeddings = value
642
-
643
- def _prune_heads(self, heads_to_prune):
644
- """
645
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
646
- class PreTrainedModel
647
- """
648
- for layer, heads in heads_to_prune.items():
649
- self.encoder.layer[layer].attention.prune_heads(heads)
650
-
651
- def get_extended_attention_mask(
652
- self,
653
- attention_mask: Tensor,
654
- input_shape: Tuple[int],
655
- device: device,
656
- is_decoder: bool,
657
- has_query: bool = False,
658
- ) -> Tensor:
659
- """
660
- Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
661
-
662
- Arguments:
663
- attention_mask (:obj:`torch.Tensor`):
664
- Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
665
- input_shape (:obj:`Tuple[int]`):
666
- The shape of the input to the model.
667
- device: (:obj:`torch.device`):
668
- The device of the input to the model.
669
-
670
- Returns:
671
- :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
672
- """
673
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
674
- # ourselves in which case we just need to make it broadcastable to all heads.
675
- if attention_mask.dim() == 3:
676
- extended_attention_mask = attention_mask[:, None, :, :]
677
- elif attention_mask.dim() == 2:
678
- # Provided a padding mask of dimensions [batch_size, seq_length]
679
- # - if the model is a decoder, apply a causal mask in addition to the padding mask
680
- # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
681
- if is_decoder:
682
- batch_size, seq_length = input_shape
683
-
684
- seq_ids = torch.arange(seq_length, device=device)
685
- causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
686
-
687
- # add a prefix ones mask to the causal mask
688
- # causal and attention masks must have same type with pytorch version < 1.3
689
- causal_mask = causal_mask.to(attention_mask.dtype)
690
-
691
- if causal_mask.shape[1] < attention_mask.shape[1]:
692
- prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
693
- if has_query: # UniLM style attention mask
694
- causal_mask = torch.cat(
695
- [
696
- torch.zeros(
697
- (batch_size, prefix_seq_len, seq_length),
698
- device=device,
699
- dtype=causal_mask.dtype,
700
- ),
701
- causal_mask,
702
- ],
703
- axis=1,
704
- )
705
- causal_mask = torch.cat(
706
- [
707
- torch.ones(
708
- (batch_size, causal_mask.shape[1], prefix_seq_len),
709
- device=device,
710
- dtype=causal_mask.dtype,
711
- ),
712
- causal_mask,
713
- ],
714
- axis=-1,
715
- )
716
- extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
717
- else:
718
- extended_attention_mask = attention_mask[:, None, None, :]
719
- else:
720
- raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))
721
-
722
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
723
- # masked positions, this operation will create a tensor which is 0.0 for
724
- # positions we want to attend and -10000.0 for masked positions.
725
- # Since we are adding it to the raw scores before the softmax, this is
726
- # effectively the same as removing these entirely.
727
- extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
728
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
729
- return extended_attention_mask
730
-
731
- def forward(
732
- self,
733
- input_ids=None,
734
- attention_mask=None,
735
- position_ids=None,
736
- head_mask=None,
737
- query_embeds=None,
738
- encoder_hidden_states=None,
739
- encoder_attention_mask=None,
740
- past_key_values=None,
741
- use_cache=None,
742
- output_attentions=None,
743
- output_hidden_states=None,
744
- return_dict=None,
745
- is_decoder=False,
746
- ):
747
- r"""
748
- encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
749
- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
750
- the model is configured as a decoder.
751
- encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
752
- Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
753
- the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
754
- - 1 for tokens that are **not masked**,
755
- - 0 for tokens that are **masked**.
756
- past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
757
- Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
758
- If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
759
- (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
760
- instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
761
- use_cache (:obj:`bool`, `optional`):
762
- If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
763
- decoding (see :obj:`past_key_values`).
764
- """
765
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
766
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
767
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
768
-
769
- # use_cache = use_cache if use_cache is not None else self.config.use_cache
770
-
771
- if input_ids is None:
772
- assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"
773
-
774
- # past_key_values_length
775
- past_key_values_length = past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
776
-
777
- query_length = query_embeds.shape[1] if query_embeds is not None else 0
778
-
779
- embedding_output = self.embeddings(
780
- input_ids=input_ids,
781
- position_ids=position_ids,
782
- query_embeds=query_embeds,
783
- past_key_values_length=past_key_values_length,
784
- )
785
-
786
- input_shape = embedding_output.size()[:-1]
787
- batch_size, seq_length = input_shape
788
- device = embedding_output.device
789
-
790
- if attention_mask is None:
791
- attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
792
-
793
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
794
- # ourselves in which case we just need to make it broadcastable to all heads.
795
- if is_decoder:
796
- extended_attention_mask = self.get_extended_attention_mask(
797
- attention_mask,
798
- input_ids.shape,
799
- device,
800
- is_decoder,
801
- has_query=(query_embeds is not None),
802
- )
803
- else:
804
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)
805
-
806
- # If a 2D or 3D attention mask is provided for the cross-attention
807
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
808
- if encoder_hidden_states is not None:
809
- if type(encoder_hidden_states) == list:
810
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
811
- else:
812
- (
813
- encoder_batch_size,
814
- encoder_sequence_length,
815
- _,
816
- ) = encoder_hidden_states.size()
817
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
818
-
819
- if type(encoder_attention_mask) == list:
820
- encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
821
- elif encoder_attention_mask is None:
822
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
823
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
824
- else:
825
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
826
- else:
827
- encoder_extended_attention_mask = None
828
-
829
- # Prepare head mask if needed
830
- # 1.0 in head_mask indicate we keep the head
831
- # attention_probs has shape bsz x n_heads x N x N
832
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
833
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
834
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
835
-
836
- encoder_outputs = self.encoder(
837
- embedding_output,
838
- attention_mask=extended_attention_mask,
839
- head_mask=head_mask,
840
- encoder_hidden_states=encoder_hidden_states,
841
- encoder_attention_mask=encoder_extended_attention_mask,
842
- past_key_values=past_key_values,
843
- use_cache=use_cache,
844
- output_attentions=output_attentions,
845
- output_hidden_states=output_hidden_states,
846
- return_dict=return_dict,
847
- query_length=query_length,
848
- )
849
- sequence_output = encoder_outputs[0]
850
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
851
-
852
- if not return_dict:
853
- return (sequence_output, pooled_output) + encoder_outputs[1:]
854
-
855
- return BaseModelOutputWithPoolingAndCrossAttentions(
856
- last_hidden_state=sequence_output,
857
- pooler_output=pooled_output,
858
- past_key_values=encoder_outputs.past_key_values,
859
- hidden_states=encoder_outputs.hidden_states,
860
- attentions=encoder_outputs.attentions,
861
- cross_attentions=encoder_outputs.cross_attentions,
862
- )
863
-
864
-
865
- class BertLMHeadModel(BertPreTrainedModel):
866
-
867
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
868
- _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
869
-
870
- def __init__(self, config):
871
- super().__init__(config)
872
-
873
- self.bert = BertModel(config, add_pooling_layer=False)
874
- self.cls = BertOnlyMLMHead(config)
875
-
876
- self.init_weights()
877
-
878
- def get_output_embeddings(self):
879
- return self.cls.predictions.decoder
880
-
881
- def set_output_embeddings(self, new_embeddings):
882
- self.cls.predictions.decoder = new_embeddings
883
-
884
- def forward(
885
- self,
886
- input_ids=None,
887
- attention_mask=None,
888
- position_ids=None,
889
- head_mask=None,
890
- query_embeds=None,
891
- encoder_hidden_states=None,
892
- encoder_attention_mask=None,
893
- labels=None,
894
- past_key_values=None,
895
- use_cache=True,
896
- output_attentions=None,
897
- output_hidden_states=None,
898
- return_dict=None,
899
- return_logits=False,
900
- is_decoder=True,
901
- reduction="mean",
902
- ):
903
- r"""
904
- encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
905
- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
906
- the model is configured as a decoder.
907
- encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
908
- Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
909
- the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
910
- - 1 for tokens that are **not masked**,
911
- - 0 for tokens that are **masked**.
912
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
913
- Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
914
- ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
915
- ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
916
- past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
917
- Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
918
- If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
919
- (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
920
- instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
921
- use_cache (:obj:`bool`, `optional`):
922
- If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
923
- decoding (see :obj:`past_key_values`).
924
- Returns:
925
- Example::
926
- >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
927
- >>> import torch
928
- >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
929
- >>> config = BertConfig.from_pretrained("bert-base-cased")
930
- >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
931
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
932
- >>> outputs = model(**inputs)
933
- >>> prediction_logits = outputs.logits
934
- """
935
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
936
- if labels is not None:
937
- use_cache = False
938
- if past_key_values is not None:
939
- query_embeds = None
940
-
941
- outputs = self.bert(
942
- input_ids,
943
- attention_mask=attention_mask,
944
- position_ids=position_ids,
945
- head_mask=head_mask,
946
- query_embeds=query_embeds,
947
- encoder_hidden_states=encoder_hidden_states,
948
- encoder_attention_mask=encoder_attention_mask,
949
- past_key_values=past_key_values,
950
- use_cache=use_cache,
951
- output_attentions=output_attentions,
952
- output_hidden_states=output_hidden_states,
953
- return_dict=return_dict,
954
- is_decoder=is_decoder,
955
- )
956
-
957
- sequence_output = outputs[0]
958
- if query_embeds is not None:
959
- sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
960
-
961
- prediction_scores = self.cls(sequence_output)
962
-
963
- if return_logits:
964
- return prediction_scores[:, :-1, :].contiguous()
965
-
966
- lm_loss = None
967
- if labels is not None:
968
- # we are doing next-token prediction; shift prediction scores and input ids by one
969
- shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
970
- labels = labels[:, 1:].contiguous()
971
- loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
972
- lm_loss = loss_fct(
973
- shifted_prediction_scores.view(-1, self.config.vocab_size),
974
- labels.view(-1),
975
- )
976
- if reduction == "none":
977
- lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
978
-
979
- if not return_dict:
980
- output = (prediction_scores,) + outputs[2:]
981
- return ((lm_loss,) + output) if lm_loss is not None else output
982
-
983
- return CausalLMOutputWithCrossAttentions(
984
- loss=lm_loss,
985
- logits=prediction_scores,
986
- past_key_values=outputs.past_key_values,
987
- hidden_states=outputs.hidden_states,
988
- attentions=outputs.attentions,
989
- cross_attentions=outputs.cross_attentions,
990
- )
991
-
992
- def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
993
- # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
994
- if attention_mask is None:
995
- attention_mask = input_ids.new_ones(input_ids.shape)
996
- query_mask = input_ids.new_ones(query_embeds.shape[:-1])
997
- attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
998
-
999
- # cut decoder_input_ids if past is used
1000
- if past is not None:
1001
- input_ids = input_ids[:, -1:]
1002
-
1003
- return {
1004
- "input_ids": input_ids,
1005
- "query_embeds": query_embeds,
1006
- "attention_mask": attention_mask,
1007
- "past_key_values": past,
1008
- "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1009
- "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1010
- "is_decoder": True,
1011
- }
1012
-
1013
- def _reorder_cache(self, past, beam_idx):
1014
- reordered_past = ()
1015
- for layer_past in past:
1016
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1017
- return reordered_past
1018
-
1019
-
1020
- class BertForMaskedLM(BertPreTrainedModel):
1021
-
1022
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
1023
- _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1024
-
1025
- def __init__(self, config):
1026
- super().__init__(config)
1027
-
1028
- self.bert = BertModel(config, add_pooling_layer=False)
1029
- self.cls = BertOnlyMLMHead(config)
1030
-
1031
- self.init_weights()
1032
-
1033
- def get_output_embeddings(self):
1034
- return self.cls.predictions.decoder
1035
-
1036
- def set_output_embeddings(self, new_embeddings):
1037
- self.cls.predictions.decoder = new_embeddings
1038
-
1039
- def forward(
1040
- self,
1041
- input_ids=None,
1042
- attention_mask=None,
1043
- position_ids=None,
1044
- head_mask=None,
1045
- query_embeds=None,
1046
- encoder_hidden_states=None,
1047
- encoder_attention_mask=None,
1048
- labels=None,
1049
- output_attentions=None,
1050
- output_hidden_states=None,
1051
- return_dict=None,
1052
- return_logits=False,
1053
- is_decoder=False,
1054
- ):
1055
- r"""
1056
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1057
- Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1058
- config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1059
- (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1060
- """
1061
-
1062
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1063
-
1064
- outputs = self.bert(
1065
- input_ids,
1066
- attention_mask=attention_mask,
1067
- position_ids=position_ids,
1068
- head_mask=head_mask,
1069
- query_embeds=query_embeds,
1070
- encoder_hidden_states=encoder_hidden_states,
1071
- encoder_attention_mask=encoder_attention_mask,
1072
- output_attentions=output_attentions,
1073
- output_hidden_states=output_hidden_states,
1074
- return_dict=return_dict,
1075
- is_decoder=is_decoder,
1076
- )
1077
-
1078
- if query_embeds is not None:
1079
- sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1080
- prediction_scores = self.cls(sequence_output)
1081
-
1082
- if return_logits:
1083
- return prediction_scores
1084
-
1085
- masked_lm_loss = None
1086
- if labels is not None:
1087
- loss_fct = CrossEntropyLoss() # -100 index = padding token
1088
- masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1089
-
1090
- if not return_dict:
1091
- output = (prediction_scores,) + outputs[2:]
1092
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1093
-
1094
- return MaskedLMOutput(
1095
- loss=masked_lm_loss,
1096
- logits=prediction_scores,
1097
- hidden_states=outputs.hidden_states,
1098
- attentions=outputs.attentions,
1099
- )
1100
-
1101
-
1102
- class Qformer(nn.Module):
1103
- def __init__(self, model_args, vision_tower):
1104
- super().__init__()
1105
-
1106
- self.depth = model_args.mm_qformer_depth
1107
- self.num_latents = model_args.mm_qformer_latents
1108
- self.pretrained = model_args.mm_qformer_pretrained
1109
-
1110
- self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)
1111
-
1112
- if self.pretrained is not None:
1113
- pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"]
1114
- pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")}
1115
- self.load_state_dict(pretrained_dict)
1116
-
1117
- def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
1118
- encoder_config = BertConfig.from_pretrained("bert-base-uncased")
1119
- encoder_config.encoder_width = vision_width
1120
- # insert cross-attention layer every other block
1121
- encoder_config.add_cross_attention = True
1122
- encoder_config.cross_attention_freq = cross_attention_freq
1123
- encoder_config.query_length = num_query_token
1124
- Qformer = BertLMHeadModel(config=encoder_config)
1125
- query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
1126
- query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
1127
- Qformer.cls = None
1128
- Qformer.bert.embeddings.word_embeddings = None
1129
- Qformer.bert.embeddings.position_embeddings = None
1130
- for layer in Qformer.bert.encoder.layer:
1131
- layer.output = None
1132
- layer.intermediate = None
1133
- return Qformer, query_tokens, nn.LayerNorm(vision_width)
1134
-
1135
- def forward(self, image_features, *args, **kwargs):
1136
- x = self.ln_vision(image_features)
1137
- image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device)
1138
-
1139
- query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
1140
- query_output = self.Qformer.bert(
1141
- query_embeds=query_tokens,
1142
- encoder_hidden_states=x,
1143
- encoder_attention_mask=image_atts,
1144
- return_dict=True,
1145
- )
1146
-
1147
- return query_output.last_hidden_state
1148
-
1149
- @property
1150
- def hidden_size(self):
1151
- return 768
1152
-
1153
- @property
1154
- def config(self):
1155
- return {
1156
- "mm_resampler_type": "qformer",
1157
- "mm_qformer_depth": self.depth,
1158
- "mm_qformer_latents": self.num_latents,
1159
- "mm_qformer_pretrained": self.pretrained,
1160
- }
 
 
 
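The Qformer resampler deleted above compresses a variable-length sequence of vision-tower features into a fixed set of learned query tokens: `query_tokens` is expanded per batch element and cross-attends to the `ln_vision`-normalized image features inside `Qformer.bert`. A minimal stand-in sketch of that idea, using `torch.nn.MultiheadAttention` in place of the BERT stack (the class name and dimensions below are hypothetical, chosen only for illustration):

```python
import torch
import torch.nn as nn

class TinyQueryResampler(nn.Module):
    """Stand-in for the removed Q-Former: learned queries cross-attend to vision features."""
    def __init__(self, vision_width=1152, hidden=768, num_queries=32, heads=8):
        super().__init__()
        self.ln_vision = nn.LayerNorm(vision_width)           # mirrors ln_vision
        self.proj = nn.Linear(vision_width, hidden)            # match widths for the attention
        self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, hidden))
        nn.init.normal_(self.query_tokens, std=0.02)           # mirrors query_tokens init
        self.attn = nn.MultiheadAttention(hidden, heads, batch_first=True)

    def forward(self, image_features):                         # (B, N, vision_width)
        x = self.proj(self.ln_vision(image_features))          # (B, N, hidden)
        q = self.query_tokens.expand(x.shape[0], -1, -1)       # (B, num_queries, hidden)
        out, _ = self.attn(q, x, x)                            # queries attend to patch tokens
        return out                                             # fixed length: (B, num_queries, hidden)

feats = torch.randn(2, 576, 1152)                              # e.g. a 24x24 patch grid
print(TinyQueryResampler()(feats).shape)                       # torch.Size([2, 32, 768])
```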
 
 
 
 
 
 
 
 
 
 
 
 
sae.py CHANGED
@@ -1,8 +1,1440 @@
1
  import torch
2
-
3
- from .sae_utils import SamePadConv3d,Normalize,SiLU,TemporalAttention,AttnBlock3D,MultiHeadAttention3D,TemporalAttention_lin
4
  import torch.nn as nn
5
  import pdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class SiglipAE(nn.Module):
8
  def __init__(self):
@@ -34,12 +1466,4 @@ class SiglipAE(nn.Module):
34
 
35
  x=self.encoder(x)
36
  return x
37
- # image=torch.randn(1,1152,4,24,24).to('cuda')
38
-
39
-
40
- # model = SiglipAE().to('cuda')
41
- # model.load_state_dict(torch.load('encoder.pth'),strict=False)
42
-
43
- # image=model(image)
44
 
45
- # print(image.shape)
 
1
  import torch
 
 
2
  import torch.nn as nn
3
  import pdb
4
+ import math
5
+ from transformers.activations import ACT2FN
6
+ from einops import rearrange, reduce, repeat
7
+ from inspect import isfunction
8
+ import math
9
+ import torch.nn.functional as F
10
+ from torch import nn, einsum
11
+ from einops import rearrange, repeat
12
+ from typing import Optional, Any
13
+
14
+ try:
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ XFORMERS_IS_AVAILBLE = True
19
+ except:
20
+ XFORMERS_IS_AVAILBLE = False
21
+
22
+ import importlib
23
+ import numpy as np
24
+ import cv2, os
25
+ import torch.distributed as dist
26
+
27
+
28
+ def count_params(model, verbose=False):
29
+ total_params = sum(p.numel() for p in model.parameters())
30
+ if verbose:
31
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
32
+ return total_params
33
+
34
+
35
+ def check_istarget(name, para_list):
36
+ """
37
+ name: full name of source para
38
+ para_list: partial name of target para
39
+ """
40
+ istarget = False
41
+ for para in para_list:
42
+ if para in name:
43
+ return True
44
+ return istarget
45
+
46
+
47
+ def instantiate_from_config(config):
48
+ if not "target" in config:
49
+ if config == "__is_first_stage__":
50
+ return None
51
+ elif config == "__is_unconditional__":
52
+ return None
53
+ raise KeyError("Expected key `target` to instantiate.")
54
+
55
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
56
+
57
+
58
+ def get_obj_from_str(string, reload=False):
59
+ module, cls = string.rsplit(".", 1)
60
+ if reload:
61
+ module_imp = importlib.import_module(module)
62
+ importlib.reload(module_imp)
63
+ return getattr(importlib.import_module(module, package=None), cls)
64
+
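`instantiate_from_config` expects a dict with a dotted `target` path plus optional `params`, and `get_obj_from_str` resolves that path with `importlib`. A quick sanity check of the same lookup logic (the config entry below is made up purely for illustration):

```python
import importlib

config = {"target": "torch.nn.SiLU", "params": {}}     # hypothetical config entry

# Same resolution that get_obj_from_str performs: split the dotted path, import, getattr.
module, cls = config["target"].rsplit(".", 1)
obj = getattr(importlib.import_module(module), cls)(**config.get("params", {}))
print(obj)   # SiLU()
```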
65
+
66
+ def load_npz_from_dir(data_dir):
67
+ data = [
68
+ np.load(os.path.join(data_dir, data_name))["arr_0"]
69
+ for data_name in os.listdir(data_dir)
70
+ ]
71
+ data = np.concatenate(data, axis=0)
72
+ return data
73
+
74
+
75
+ def load_npz_from_paths(data_paths):
76
+ data = [np.load(data_path)["arr_0"] for data_path in data_paths]
77
+ data = np.concatenate(data, axis=0)
78
+ return data
79
+
80
+
81
+ def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
82
+ h, w = image.shape[:2]
83
+ if resize_short_edge is not None:
84
+ k = resize_short_edge / min(h, w)
85
+ else:
86
+ k = max_resolution / (h * w)
87
+ k = k**0.5
88
+ h = int(np.round(h * k / 64)) * 64
89
+ w = int(np.round(w * k / 64)) * 64
90
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
91
+ return image
92
+
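`resize_numpy_image` picks a scale factor `k` so the output area lands near `max_resolution` (or so the short edge matches `resize_short_edge`), then snaps both sides to multiples of 64 before the LANCZOS resize. The arithmetic in isolation, without the `cv2` call (the input size is an arbitrary example):

```python
import numpy as np

h, w, max_resolution = 720, 1280, 512 * 512
k = (max_resolution / (h * w)) ** 0.5        # uniform scale toward the target area
new_h = int(np.round(h * k / 64)) * 64       # snap each side to a multiple of 64
new_w = int(np.round(w * k / 64)) * 64
print(new_h, new_w, new_h * new_w)           # 384 704 270336 (close to 262144)
```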
93
+
94
+ def setup_dist(args):
95
+ if dist.is_initialized():
96
+ return
97
+ torch.cuda.set_device(args.local_rank)
98
+ torch.distributed.init_process_group("nccl", init_method="env://")
99
+
100
+
101
+ # adopted from
102
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
103
+ # and
104
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
105
+ # and
106
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
107
+ #
108
+ # thanks!
109
+
110
+ import torch.nn as nn
111
+ import math
112
+ from inspect import isfunction
113
+ import torch
114
+ from torch import nn
115
+ import torch.distributed as dist
116
+
117
+
118
+ def gather_data(data, return_np=True):
119
+ """gather data from multiple processes to one list"""
120
+ data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
121
+ dist.all_gather(data_list, data) # gather not supported with NCCL
122
+ if return_np:
123
+ data_list = [data.cpu().numpy() for data in data_list]
124
+ return data_list
125
+
126
+
127
+ def autocast(f):
128
+ def do_autocast(*args, **kwargs):
129
+ with torch.cuda.amp.autocast(
130
+ enabled=True,
131
+ dtype=torch.get_autocast_gpu_dtype(),
132
+ cache_enabled=torch.is_autocast_cache_enabled(),
133
+ ):
134
+ return f(*args, **kwargs)
135
+
136
+ return do_autocast
137
+
138
+
139
+ def extract_into_tensor(a, t, x_shape):
140
+ b, *_ = t.shape
141
+ out = a.gather(-1, t)
142
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
143
+
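`extract_into_tensor` gathers one scalar per batch element from a 1-D schedule `a` at the timesteps `t`, then reshapes the result so it broadcasts against an image- or video-shaped tensor. A minimal shape check (the schedule values are arbitrary):

```python
import torch

a = torch.linspace(0.9, 0.1, steps=1000)         # per-timestep coefficients
t = torch.tensor([0, 500, 999])                  # one timestep per batch element
x = torch.randn(3, 4, 16, 32, 32)                # (B, C, T, H, W) latent

coef = a.gather(-1, t).reshape(3, 1, 1, 1, 1)    # same gather + reshape as the helper
print(coef.shape, (coef * x).shape)              # torch.Size([3, 1, 1, 1, 1]) torch.Size([3, 4, 16, 32, 32])
```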
144
+
145
+ def noise_like(shape, device, repeat=False):
146
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
147
+ shape[0], *((1,) * (len(shape) - 1))
148
+ )
149
+ noise = lambda: torch.randn(shape, device=device)
150
+ return repeat_noise() if repeat else noise()
151
+
152
+
153
+ def default(val, d):
154
+ if exists(val):
155
+ return val
156
+ return d() if isfunction(d) else d
157
+
158
+
159
+ def exists(val):
160
+ return val is not None
161
+
162
+
163
+ def identity(*args, **kwargs):
164
+ return nn.Identity()
165
+
166
+
167
+ def uniq(arr):
168
+ return {el: True for el in arr}.keys()
169
+
170
+
171
+ def mean_flat(tensor):
172
+ """
173
+ Take the mean over all non-batch dimensions.
174
+ """
175
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
176
+
177
+
178
+ def ismap(x):
179
+ if not isinstance(x, torch.Tensor):
180
+ return False
181
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
182
+
183
+
184
+ def isimage(x):
185
+ if not isinstance(x, torch.Tensor):
186
+ return False
187
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
188
+
189
+
190
+ def max_neg_value(t):
191
+ return -torch.finfo(t.dtype).max
192
+
193
+
194
+ def shape_to_str(x):
195
+ shape_str = "x".join([str(x) for x in x.shape])
196
+ return shape_str
197
+
198
+
199
+ def init_(tensor):
200
+ dim = tensor.shape[-1]
201
+ std = 1 / math.sqrt(dim)
202
+ tensor.uniform_(-std, std)
203
+ return tensor
204
+
205
+
206
+
207
+ def disabled_train(self, mode=True):
208
+ """Overwrite model.train with this function to make sure train/eval mode
209
+ does not change anymore."""
210
+ return self
211
+
212
+
213
+ def zero_module(module):
214
+ """
215
+ Zero out the parameters of a module and return it.
216
+ """
217
+ for p in module.parameters():
218
+ p.detach().zero_()
219
+ return module
220
+
221
+
222
+ def scale_module(module, scale):
223
+ """
224
+ Scale the parameters of a module and return it.
225
+ """
226
+ for p in module.parameters():
227
+ p.detach().mul_(scale)
228
+ return module
229
+
230
+
231
+ def conv_nd(dims, *args, **kwargs):
232
+ """
233
+ Create a 1D, 2D, or 3D convolution module.
234
+ """
235
+ if dims == 1:
236
+ return nn.Conv1d(*args, **kwargs)
237
+ elif dims == 2:
238
+ return nn.Conv2d(*args, **kwargs)
239
+ elif dims == 3:
240
+ return nn.Conv3d(*args, **kwargs)
241
+ raise ValueError(f"unsupported dimensions: {dims}")
242
+
243
+
244
+ def linear(*args, **kwargs):
245
+ """
246
+ Create a linear module.
247
+ """
248
+ return nn.Linear(*args, **kwargs)
249
+
250
+
251
+ def avg_pool_nd(dims, *args, **kwargs):
252
+ """
253
+ Create a 1D, 2D, or 3D average pooling module.
254
+ """
255
+ if dims == 1:
256
+ return nn.AvgPool1d(*args, **kwargs)
257
+ elif dims == 2:
258
+ return nn.AvgPool2d(*args, **kwargs)
259
+ elif dims == 3:
260
+ return nn.AvgPool3d(*args, **kwargs)
261
+ raise ValueError(f"unsupported dimensions: {dims}")
262
+
263
+
264
+ def nonlinearity(type="silu"):
265
+ if type == "silu":
266
+ return nn.SiLU()
267
+ elif type == "leaky_relu":
268
+ return nn.LeakyReLU()
269
+
270
+
271
+ class GroupNormSpecific(nn.GroupNorm):
272
+ def forward(self, x):
273
+ if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
274
+ return super().forward(x).type(x.dtype)
275
+ else:
276
+ return super().forward(x.float()).type(x.dtype)
277
+
278
+
279
+ def normalization(channels, num_groups=32):
280
+ """
281
+ Make a standard normalization layer.
282
+ :param channels: number of input channels.
283
+ :return: an nn.Module for normalization.
284
+ """
285
+ return GroupNormSpecific(num_groups, channels)
286
+
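`normalization` builds a 32-group GroupNorm (via `GroupNormSpecific`, which only adjusts the compute dtype so half-precision activations keep their dtype on the way out). A small float32 usage sketch with arbitrary channel counts:

```python
import torch
import torch.nn as nn

gn = nn.GroupNorm(num_groups=32, num_channels=256)   # the layer normalization(256) wraps
x = torch.randn(2, 256, 4, 24, 24)                   # (B, C, T, H, W); GroupNorm accepts any trailing dims
y = gn(x)
print(y.shape, round(y.mean().item(), 4))            # shape preserved, overall mean ~0
```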
287
+
288
+ class HybridConditioner(nn.Module):
289
+
290
+ def __init__(self, c_concat_config, c_crossattn_config):
291
+ super().__init__()
292
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
293
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
294
+
295
+ def forward(self, c_concat, c_crossattn):
296
+ c_concat = self.concat_conditioner(c_concat)
297
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
298
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
299
+
300
+ def exists(val):
301
+ return val is not None
302
+
303
+
304
+ def uniq(arr):
305
+ return {el: True for el in arr}.keys()
306
+
307
+
308
+ def default(val, d):
309
+ if exists(val):
310
+ return val
311
+ return d() if isfunction(d) else d
312
+
313
+
314
+ def max_neg_value(t):
315
+ return -torch.finfo(t.dtype).max
316
+
317
+
318
+ def init_(tensor):
319
+ dim = tensor.shape[-1]
320
+ std = 1 / math.sqrt(dim)
321
+ tensor.uniform_(-std, std)
322
+ return tensor
323
+
324
+
325
+ # feedforward
326
+ class GEGLU(nn.Module):
327
+ def __init__(self, dim_in, dim_out):
328
+ super().__init__()
329
+ self.proj = nn.Linear(dim_in, dim_out * 2)
330
+
331
+ def forward(self, x):
332
+ x, gate = self.proj(x).chunk(2, dim=-1)
333
+ return x * F.gelu(gate)
334
+
335
+
336
+ class FeedForward(nn.Module):
337
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
338
+ super().__init__()
339
+ inner_dim = int(dim * mult)
340
+ dim_out = default(dim_out, dim)
341
+ project_in = (
342
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
343
+ if not glu
344
+ else GEGLU(dim, inner_dim)
345
+ )
346
+
347
+ self.net = nn.Sequential(
348
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
349
+ )
350
+
351
+ def forward(self, x):
352
+ return self.net(x)
353
+
354
+
355
+ def zero_module(module):
356
+ """
357
+ Zero out the parameters of a module and return it.
358
+ """
359
+ for p in module.parameters():
360
+ p.detach().zero_()
361
+ return module
362
+
363
+
364
+ def Normalize(in_channels, num_groups=32):
365
+ return torch.nn.GroupNorm(
366
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
367
+ )
368
+
369
+
370
+ class RelativePosition(nn.Module):
371
+ """https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py"""
372
+
373
+ def __init__(self, num_units, max_relative_position):
374
+ super().__init__()
375
+ self.num_units = num_units
376
+ self.max_relative_position = max_relative_position
377
+ self.embeddings_table = nn.Parameter(
378
+ torch.Tensor(max_relative_position * 2 + 1, num_units)
379
+ )
380
+ nn.init.xavier_uniform_(self.embeddings_table)
381
+
382
+ def forward(self, length_q, length_k):
383
+ device = self.embeddings_table.device
384
+ range_vec_q = torch.arange(length_q, device=device)
385
+ range_vec_k = torch.arange(length_k, device=device)
386
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
387
+ distance_mat_clipped = torch.clamp(
388
+ distance_mat, -self.max_relative_position, self.max_relative_position
389
+ )
390
+ final_mat = distance_mat_clipped + self.max_relative_position
391
+ # final_mat = torch.LongTensor(final_mat).to(self.embeddings_table.device)
392
+ # final_mat = torch.tensor(final_mat, device=self.embeddings_table.device, dtype=torch.long)
393
+ final_mat = final_mat.long()
394
+ embeddings = self.embeddings_table[final_mat]
395
+ return embeddings
396
+
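`RelativePosition` builds a `(length_q, length_k)` table of clipped frame offsets and uses it to index a learned embedding table of size `2 * max_relative_position + 1`. The indexing logic in isolation (sizes chosen only for illustration):

```python
import torch

length_q = length_k = 6
max_rel = 4
range_q = torch.arange(length_q)
range_k = torch.arange(length_k)
distance = range_k[None, :] - range_q[:, None]           # signed frame offsets
index = distance.clamp(-max_rel, max_rel) + max_rel      # shift into [0, 2 * max_rel]

table = torch.randn(2 * max_rel + 1, 64)                 # (2*max_relative_position+1, num_units)
emb = table[index]                                       # (length_q, length_k, num_units)
print(index[0].tolist(), emb.shape)                      # [4, 5, 6, 7, 8, 8] torch.Size([6, 6, 64])
```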
397
+
398
+ class TemporalCrossAttention(nn.Module):
399
+ def __init__(
400
+ self,
401
+ query_dim,
402
+ context_dim=None,
403
+ heads=8,
404
+ dim_head=64,
405
+ dropout=0.0,
406
+ temporal_length=None, # For relative positional representation and image-video joint training.
407
+ image_length=None, # For image-video joint training.
408
+ use_relative_position=False, # whether use relative positional representation in temporal attention.
409
+ img_video_joint_train=False, # For image-video joint training.
410
+ use_tempoal_causal_attn=False,
411
+ bidirectional_causal_attn=False,
412
+ tempoal_attn_type=None,
413
+ joint_train_mode="same_batch",
414
+ **kwargs,
415
+ ):
416
+ super().__init__()
417
+ inner_dim = dim_head * heads
418
+ context_dim = default(context_dim, query_dim)
419
+ self.context_dim = context_dim
420
+
421
+ self.scale = dim_head**-0.5
422
+ self.heads = heads
423
+ self.temporal_length = temporal_length
424
+ self.use_relative_position = use_relative_position
425
+ self.img_video_joint_train = img_video_joint_train
426
+ self.bidirectional_causal_attn = bidirectional_causal_attn
427
+ self.joint_train_mode = joint_train_mode
428
+ assert joint_train_mode in ["same_batch", "diff_batch"]
429
+ self.tempoal_attn_type = tempoal_attn_type
430
+
431
+ if bidirectional_causal_attn:
432
+ assert use_tempoal_causal_attn
433
+ if tempoal_attn_type:
434
+ assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"]
435
+ assert not use_tempoal_causal_attn
436
+ assert not (
437
+ img_video_joint_train and (self.joint_train_mode == "same_batch")
438
+ )
439
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
440
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
441
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
442
+
443
+ assert not (
444
+ img_video_joint_train
445
+ and (self.joint_train_mode == "same_batch")
446
+ and use_tempoal_causal_attn
447
+ )
448
+ if img_video_joint_train:
449
+ if self.joint_train_mode == "same_batch":
450
+ mask = torch.ones(
451
+ [1, temporal_length + image_length, temporal_length + image_length]
452
+ )
453
+ # mask[:, image_length:, :] = 0
454
+ # mask[:, :, image_length:] = 0
455
+ mask[:, temporal_length:, :] = 0
456
+ mask[:, :, temporal_length:] = 0
457
+ self.mask = mask
458
+ else:
459
+ self.mask = None
460
+ elif use_tempoal_causal_attn:
461
+ # normal causal attn
462
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
463
+ elif tempoal_attn_type == "sparse_causal":
464
+ # all frames interact with only the `prev` & self frame
465
+ mask1 = torch.tril(
466
+ torch.ones([1, temporal_length, temporal_length])
467
+ ).bool() # true indicates keeping
468
+ mask2 = torch.zeros(
469
+ [1, temporal_length, temporal_length]
470
+ ) # initialize to same shape with mask1
471
+ mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
472
+ torch.ones([1, temporal_length - 2, temporal_length - 2])
473
+ )
474
+ mask2 = (1 - mask2).bool() # false indicates masking
475
+ self.mask = mask1 & mask2
476
+ elif tempoal_attn_type == "sparse_causal_first":
477
+ # all frames interact with only the `first` & self frame
478
+ mask1 = torch.tril(
479
+ torch.ones([1, temporal_length, temporal_length])
480
+ ).bool() # true indicates keeping
481
+ mask2 = torch.zeros([1, temporal_length, temporal_length])
482
+ mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril(
483
+ torch.ones([1, temporal_length - 2, temporal_length - 2])
484
+ )
485
+ mask2 = (1 - mask2).bool() # false indicates masking
486
+ self.mask = mask1 & mask2
487
+ else:
488
+ self.mask = None
489
+
490
+ if use_relative_position:
491
+ assert temporal_length is not None
492
+ self.relative_position_k = RelativePosition(
493
+ num_units=dim_head, max_relative_position=temporal_length
494
+ )
495
+ self.relative_position_v = RelativePosition(
496
+ num_units=dim_head, max_relative_position=temporal_length
497
+ )
498
+
499
+ self.to_out = nn.Sequential(
500
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
501
+ )
502
+
503
+ nn.init.constant_(self.to_q.weight, 0)
504
+ nn.init.constant_(self.to_k.weight, 0)
505
+ nn.init.constant_(self.to_v.weight, 0)
506
+ nn.init.constant_(self.to_out[0].weight, 0)
507
+ nn.init.constant_(self.to_out[0].bias, 0)
508
+
509
+ def forward(self, x, context=None, mask=None):
510
+ # if context is None:
511
+ # print(f'[Temp Attn] x={x.shape},context=None')
512
+ # else:
513
+ # print(f'[Temp Attn] x={x.shape},context={context.shape}')
514
+
515
+ nh = self.heads
516
+ out = x
517
+ q = self.to_q(out)
518
+ # if context is not None:
519
+ # print(f'temporal context 1 ={context.shape}')
520
+ # print(f'x={x.shape}')
521
+ context = default(context, x)
522
+ # print(f'temporal context 2 ={context.shape}')
523
+ k = self.to_k(context)
524
+ v = self.to_v(context)
525
+ # print(f'q ={q.shape},k={k.shape}')
526
+
527
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v))
528
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
529
+
530
+ if self.use_relative_position:
531
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
532
+ k2 = self.relative_position_k(len_q, len_k)
533
+ sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale # TODO check
534
+ sim += sim2
535
+ # print('mask',mask)
536
+ if exists(self.mask):
537
+ if mask is None:
538
+ mask = self.mask.to(sim.device)
539
+ else:
540
+ mask = self.mask.to(sim.device).bool() & mask # .to(sim.device)
541
+ else:
542
+ mask = mask
543
+ # if self.img_video_joint_train:
544
+ # # process mask (make mask same shape with sim)
545
+ # c, h, w = mask.shape
546
+ # c, t, s = sim.shape
547
+ # # assert(h == w and t == s),f"mask={mask.shape}, sim={sim.shape}, h={h}, w={w}, t={t}, s={s}"
548
+
549
+ # if h > t:
550
+ # mask = mask[:, :t, :]
551
+ # elif h < t: # pad zeros to mask (no attention) only initial mask =1 area compute weights
552
+ # mask_ = torch.zeros([c,t,w]).to(mask.device)
553
+ # mask_[:, :h, :] = mask
554
+ # mask = mask_
555
+ # c, h, w = mask.shape
556
+ # if w > s:
557
+ # mask = mask[:, :, :s]
558
+ # elif w < s: # pad zeros to mask
559
+ # mask_ = torch.zeros([c,h,s]).to(mask.device)
560
+ # mask_[:, :, :w] = mask
561
+ # mask = mask_
562
+
563
+ # max_neg_value = -torch.finfo(sim.dtype).max
564
+ # sim = sim.float().masked_fill(mask == 0, max_neg_value)
565
+ if mask is not None:
566
+ max_neg_value = -1e9
567
+ sim = sim + (1 - mask.float()) * max_neg_value # mask: 1 = attend, 0 = masked out via additive -1e9
568
+ # print('sim after masking: ', sim)
569
+
570
+ # if torch.isnan(sim).any() or torch.isinf(sim).any() or (not sim.any()):
571
+ # print(f'sim [after masking], isnan={torch.isnan(sim).any()}, isinf={torch.isinf(sim).any()}, allzero={not sim.any()}')
572
+
573
+ attn = sim.softmax(dim=-1)
574
+ # print('attn after softmax: ', attn)
575
+ # if torch.isnan(attn).any() or torch.isinf(attn).any() or (not attn.any()):
576
+ # print(f'attn [after softmax], isnan={torch.isnan(attn).any()}, isinf={torch.isinf(attn).any()}, allzero={not attn.any()}')
577
+
578
+ # attn = torch.where(torch.isnan(attn), torch.full_like(attn,0), attn)
579
+ # if torch.isinf(attn.detach()).any():
580
+ # import pdb;pdb.set_trace()
581
+ # if torch.isnan(attn.detach()).any():
582
+ # import pdb;pdb.set_trace()
583
+ out = einsum("b i j, b j d -> b i d", attn, v)
584
+
585
+ if self.bidirectional_causal_attn:
586
+ mask_reverse = torch.triu(
587
+ torch.ones(
588
+ [1, self.temporal_length, self.temporal_length], device=sim.device
589
+ )
590
+ )
591
+ sim_reverse = sim.float().masked_fill(mask_reverse == 0, max_neg_value)
592
+ attn_reverse = sim_reverse.softmax(dim=-1)
593
+ out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
594
+ out += out_reverse
595
+
596
+ if self.use_relative_position:
597
+ v2 = self.relative_position_v(len_q, len_v)
598
+ out2 = einsum("b t s, t s d -> b t d", attn, v2) # TODO check
599
+ out += out2 # TODO check: add first or merge heads first? Here the relative-position term is computed on the split-head tensors, then the heads are merged.
600
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=nh) # merge head
601
+ return self.to_out(out)
602
+
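`TemporalCrossAttention` is meant to run after the spatial grid has been folded into the batch dimension, so each spatial location attends over the `t` frames, and the optional causal mask is a lower-triangular `(1, t, t)` tensor. A shape-level sketch of that setup on plain tensors (not the module itself; sizes are illustrative and the q/k/v projections are omitted):

```python
import torch
from einops import rearrange

b, c, t, h, w = 2, 64, 8, 16, 16
x = torch.randn(b, c, t, h, w)

tokens = rearrange(x, "b c t h w -> (b h w) t c")   # one temporal sequence per spatial location
causal = torch.tril(torch.ones(1, t, t))            # 1 = may attend, 0 = future frame

sim = torch.einsum("b i d, b j d -> b i j", tokens, tokens) * (c ** -0.5)
sim = sim + (1 - causal) * -1e9                     # same additive masking used above
attn = sim.softmax(dim=-1)
out = torch.einsum("b i j, b j d -> b i d", attn, tokens)
print(out.shape)                                    # torch.Size([512, 8, 64])
```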
603
+
604
+ class SpatialSelfAttention(nn.Module):
605
+ def __init__(self, in_channels):
606
+ super().__init__()
607
+ self.in_channels = in_channels
608
+
609
+ self.norm = Normalize(in_channels)
610
+ self.q = torch.nn.Conv2d(
611
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
612
+ )
613
+ self.k = torch.nn.Conv2d(
614
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
615
+ )
616
+ self.v = torch.nn.Conv2d(
617
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
618
+ )
619
+ self.proj_out = torch.nn.Conv2d(
620
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
621
+ )
622
+
623
+ def forward(self, x):
624
+ h_ = x
625
+ h_ = self.norm(h_)
626
+ q = self.q(h_)
627
+ k = self.k(h_)
628
+ v = self.v(h_)
629
+
630
+ # compute attention
631
+ b, c, h, w = q.shape
632
+ q = rearrange(q, "b c h w -> b (h w) c")
633
+ k = rearrange(k, "b c h w -> b c (h w)")
634
+ w_ = torch.einsum("bij,bjk->bik", q, k)
635
+
636
+ w_ = w_ * (int(c) ** (-0.5))
637
+ w_ = torch.nn.functional.softmax(w_, dim=2)
638
+
639
+ # attend to values
640
+ v = rearrange(v, "b c h w -> b c (h w)")
641
+ w_ = rearrange(w_, "b i j -> b j i")
642
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
643
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
644
+ h_ = self.proj_out(h_)
645
+
646
+ return x + h_
647
+
648
+
649
+ class CrossAttention(nn.Module):
650
+ def __init__(
651
+ self,
652
+ query_dim,
653
+ context_dim=None,
654
+ heads=8,
655
+ dim_head=64,
656
+ dropout=0.0,
657
+ sa_shared_kv=False,
658
+ shared_type="only_first",
659
+ **kwargs,
660
+ ):
661
+ super().__init__()
662
+ inner_dim = dim_head * heads
663
+ context_dim = default(context_dim, query_dim)
664
+ self.sa_shared_kv = sa_shared_kv
665
+ assert shared_type in [
666
+ "only_first",
667
+ "all_frames",
668
+ "first_and_prev",
669
+ "only_prev",
670
+ "full",
671
+ "causal",
672
+ "full_qkv",
673
+ ]
674
+ self.shared_type = shared_type
675
+
676
+ self.scale = dim_head**-0.5
677
+ self.heads = heads
678
+ self.dim_head = dim_head
679
+
680
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
681
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
682
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
683
+
684
+ self.to_out = nn.Sequential(
685
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
686
+ )
687
+ self.attention_op: Optional[Any] = None
688
+
689
+ def forward(self, x, context=None, mask=None):
690
+ h = self.heads
691
+ b = x.shape[0]
692
+
693
+ q = self.to_q(x)
694
+ context = default(context, x)
695
+ k = self.to_k(context)
696
+ v = self.to_v(context)
697
+ if self.sa_shared_kv:
698
+ if self.shared_type == "only_first":
699
+ k, v = map(
700
+ lambda xx: rearrange(xx[0].unsqueeze(0), "b n c -> (b n) c")
701
+ .unsqueeze(0)
702
+ .repeat(b, 1, 1),
703
+ (k, v),
704
+ )
705
+ else:
706
+ raise NotImplementedError
707
+
708
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
709
+
710
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
711
+
712
+ if exists(mask):
713
+ mask = rearrange(mask, "b ... -> b (...)")
714
+ max_neg_value = -torch.finfo(sim.dtype).max
715
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
716
+ sim.masked_fill_(~mask, max_neg_value)
717
+
718
+ # attention, what we cannot get enough of
719
+ attn = sim.softmax(dim=-1)
720
+
721
+ out = einsum("b i j, b j d -> b i d", attn, v)
722
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
723
+ return self.to_out(out)
724
+
725
+ def efficient_forward(self, x, context=None, mask=None):
726
+ q = self.to_q(x)
727
+ context = default(context, x)
728
+ k = self.to_k(context)
729
+ v = self.to_v(context)
730
+
731
+ b, _, _ = q.shape
732
+ q, k, v = map(
733
+ lambda t: t.unsqueeze(3)
734
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
735
+ .permute(0, 2, 1, 3)
736
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
737
+ .contiguous(),
738
+ (q, k, v),
739
+ )
740
+ # actually compute the attention, what we cannot get enough of
741
+ out = xformers.ops.memory_efficient_attention(
742
+ q, k, v, attn_bias=None, op=self.attention_op
743
+ )
744
+
745
+ if exists(mask):
746
+ raise NotImplementedError
747
+ out = (
748
+ out.unsqueeze(0)
749
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
750
+ .permute(0, 2, 1, 3)
751
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
752
+ )
753
+ return self.to_out(out)
754
+
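`CrossAttention.forward` follows the standard pattern: project to q/k/v, fold the heads into the batch dimension with `rearrange`, score with `einsum`, softmax, aggregate, then merge the heads back. The same dance on plain tensors (a sketch, not the module; the projections are omitted and the dimensions are arbitrary):

```python
import torch
from einops import rearrange

heads, dim_head = 8, 64
q = torch.randn(2, 100, heads * dim_head)        # queries, e.g. image tokens
k = v = torch.randn(2, 77, heads * dim_head)     # context, e.g. text tokens

q, k, v = (rearrange(x, "b n (h d) -> (b h) n d", h=heads) for x in (q, k, v))
sim = torch.einsum("b i d, b j d -> b i j", q, k) * dim_head ** -0.5
out = torch.einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=heads)
print(out.shape)                                 # torch.Size([2, 100, 512])
```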
755
+
756
+ class VideoSpatialCrossAttention(CrossAttention):
757
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
758
+ super().__init__(query_dim, context_dim, heads, dim_head, dropout)
759
+
760
+ def forward(self, x, context=None, mask=None):
761
+ b, c, t, h, w = x.shape
762
+ if context is not None:
763
+ context = context.repeat(t, 1, 1)
764
+ x = super().forward(spatial_attn_reshape(x), context=context) + x
765
+ return spatial_attn_reshape_back(x, b, h)
766
+
767
+
768
+ def spatial_attn_reshape(x):
769
+ return rearrange(x, "b c t h w -> (b t) (h w) c")
770
+
771
+
772
+ def spatial_attn_reshape_back(x, b, h):
773
+ return rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
774
+
775
+
776
+ def temporal_attn_reshape(x):
777
+ return rearrange(x, "b c t h w -> (b h w) t c")
778
+
779
+
780
+ def temporal_attn_reshape_back(x, b, h, w):
781
+ return rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)
782
+
783
+
784
+ def local_spatial_temporal_attn_reshape(x, window_size):
785
+ B, C, T, H, W = x.shape
786
+ NH = H // window_size
787
+ NW = W // window_size
788
+ # x = x.view(B, C, T, NH, window_size, NW, window_size)
789
+ # tokens = x.permute(0, 1, 2, 3, 5, 4, 6).contiguous()
790
+ # tokens = tokens.view(-1, window_size, window_size, C)
791
+ x = rearrange(
792
+ x,
793
+ "b c t (nh wh) (nw ww) -> b c t nh wh nw ww",
794
+ nh=NH,
795
+ nw=NW,
796
+ wh=window_size,
797
+ ww=window_size,
798
+ ).contiguous() # # B, C, T, NH, NW, window_size, window_size
799
+ x = rearrange(
800
+ x, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c"
801
+ ) # (B, NH, NW) (T, window_size, window_size) C
802
+ return x
803
+
804
+
805
+ def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t):
806
+ B, L, C = x.shape
807
+ NH = h // window_size
808
+ NW = w // window_size
809
+ x = rearrange(
810
+ x,
811
+ "(b nh nw) (t wh ww) c -> b c t nh wh nw ww",
812
+ b=b,
813
+ nh=NH,
814
+ nw=NW,
815
+ t=t,
816
+ wh=window_size,
817
+ ww=window_size,
818
+ )
819
+ x = rearrange(x, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)")
820
+ return x
821
+
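`local_spatial_temporal_attn_reshape` turns each `window_size x window_size` spatial patch, taken across all `t` frames, into one token sequence, and the `_back` variant restores the 5D layout. A round-trip check using the same `rearrange` patterns (shapes are illustrative):

```python
import torch
from einops import rearrange

b, c, t, H, W, win = 1, 8, 4, 16, 16, 8
x = torch.randn(b, c, t, H, W)

# forward: (B*NH*NW, T*win*win, C) -- one sequence per spatial window
tok = rearrange(x, "b c t (nh wh) (nw ww) -> (b nh nw) (t wh ww) c", wh=win, ww=win)
# backward: restore the original (B, C, T, H, W) layout
y = rearrange(tok, "(b nh nw) (t wh ww) c -> b c t (nh wh) (nw ww)",
              b=b, nh=H // win, nw=W // win, t=t, wh=win, ww=win)
print(tok.shape, torch.equal(x, y))              # torch.Size([4, 256, 8]) True
```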
822
+
823
+ class SpatialTemporalTransformer(nn.Module):
824
+ """
825
+ Transformer block for video-like data (5D tensor).
826
+ First, project the input (aka embedding) with NO reshape.
827
+ Then apply standard transformer action.
828
+ The 5D -> 3D reshape operation will be done in the specific attention module.
829
+ """
830
+
831
+ def __init__(
832
+ self,
833
+ in_channels,
834
+ n_heads,
835
+ d_head,
836
+ depth=1,
837
+ dropout=0.0,
838
+ context_dim=None,
839
+ # Temporal stuff
840
+ temporal_length=None,
841
+ image_length=None,
842
+ use_relative_position=True,
843
+ img_video_joint_train=False,
844
+ cross_attn_on_tempoal=False,
845
+ temporal_crossattn_type="selfattn",
846
+ order="stst",
847
+ temporalcrossfirst=False,
848
+ split_stcontext=False,
849
+ temporal_context_dim=None,
850
+ **kwargs,
851
+ ):
852
+ super().__init__()
853
+
854
+ self.in_channels = in_channels
855
+ inner_dim = n_heads * d_head
856
+
857
+ self.norm = Normalize(in_channels)
858
+ self.proj_in = nn.Conv3d(
859
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
860
+ )
861
+
862
+ self.transformer_blocks = nn.ModuleList(
863
+ [
864
+ BasicTransformerBlockST(
865
+ inner_dim,
866
+ n_heads,
867
+ d_head,
868
+ dropout=dropout,
869
+ # cross attn
870
+ context_dim=context_dim,
871
+ # temporal attn
872
+ temporal_length=temporal_length,
873
+ image_length=image_length,
874
+ use_relative_position=use_relative_position,
875
+ img_video_joint_train=img_video_joint_train,
876
+ temporal_crossattn_type=temporal_crossattn_type,
877
+ order=order,
878
+ temporalcrossfirst=temporalcrossfirst,
879
+ split_stcontext=split_stcontext,
880
+ temporal_context_dim=temporal_context_dim,
881
+ **kwargs,
882
+ )
883
+ for d in range(depth)
884
+ ]
885
+ )
886
+
887
+ self.proj_out = zero_module(
888
+ nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
889
+ )
890
+
891
+ def forward(self, x, context=None, temporal_context=None, **kwargs):
892
+ # note: if no context is given, cross-attention defaults to self-attention
893
+ assert x.dim() == 5, f"x shape = {x.shape}"
894
+ b, c, t, h, w = x.shape
895
+ x_in = x
896
+
897
+ x = self.norm(x)
898
+ x = self.proj_in(x)
899
+
900
+ for block in self.transformer_blocks:
901
+ x = block(x, context=context, temporal_context=temporal_context, **kwargs)
902
+
903
+ x = self.proj_out(x)
904
+ return x + x_in
905
+
906
+
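A minimal sketch of the wrapper pattern used by SpatialTemporalTransformer, with an Identity stand-in for the inner transformer blocks; the class and channel sizes below are my own and only illustrate why the module is an identity mapping at initialisation (proj_out is zero-initialised):

    # Illustrative sketch (not the repo's implementation): GroupNorm, 1x1x1 Conv3d in,
    # inner blocks, zero-initialised 1x1x1 Conv3d out, then a residual connection.
    import torch
    import torch.nn as nn

    class TinySTWrapper(nn.Module):
        def __init__(self, in_channels=32, inner_dim=64):
            super().__init__()
            self.norm = nn.GroupNorm(8, in_channels)
            self.proj_in = nn.Conv3d(in_channels, inner_dim, 1)
            self.blocks = nn.Identity()                     # stand-in for BasicTransformerBlockST
            self.proj_out = nn.Conv3d(inner_dim, in_channels, 1)
            nn.init.zeros_(self.proj_out.weight)
            nn.init.zeros_(self.proj_out.bias)

        def forward(self, x):
            return self.proj_out(self.blocks(self.proj_in(self.norm(x)))) + x

    x = torch.randn(2, 32, 4, 8, 8)                          # (B, C, T, H, W)
    assert torch.allclose(TinySTWrapper()(x), x)             # identity mapping at init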
907
+ class STAttentionBlock2(nn.Module):
908
+ def __init__(
909
+ self,
910
+ channels,
911
+ num_heads=1,
912
+ num_head_channels=-1,
913
+ use_checkpoint=False, # not used, only used in ResBlock
914
+ use_new_attention_order=False, # QKVAttention or QKVAttentionLegacy
915
+ temporal_length=16, # used in relative positional representation.
916
+ image_length=8, # used for image-video joint training.
917
+ use_relative_position=False, # whether to use relative positional representations in the temporal attention.
918
+ img_video_joint_train=False,
919
+ # norm_type="groupnorm",
920
+ attn_norm_type="group",
921
+ use_tempoal_causal_attn=False,
922
+ ):
923
+ """
924
+ version 1: guided_diffusion implemented version
925
+ version 2: remove args input argument
926
+ """
927
+ super().__init__()
928
+
929
+ if num_head_channels == -1:
930
+ self.num_heads = num_heads
931
+ else:
932
+ assert (
933
+ channels % num_head_channels == 0
934
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
935
+ self.num_heads = channels // num_head_channels
936
+ self.use_checkpoint = use_checkpoint
937
+
938
+ self.temporal_length = temporal_length
939
+ self.image_length = image_length
940
+ self.use_relative_position = use_relative_position
941
+ self.img_video_joint_train = img_video_joint_train
942
+ self.attn_norm_type = attn_norm_type
943
+ assert self.attn_norm_type in ["group", "no_norm"]
944
+ self.use_tempoal_causal_attn = use_tempoal_causal_attn
945
+
946
+ if self.attn_norm_type == "group":
947
+ self.norm_s = normalization(channels)
948
+ self.norm_t = normalization(channels)
949
+
950
+ self.qkv_s = conv_nd(1, channels, channels * 3, 1)
951
+ self.qkv_t = conv_nd(1, channels, channels * 3, 1)
952
+
953
+ if self.img_video_joint_train:
954
+ mask = torch.ones(
955
+ [1, temporal_length + image_length, temporal_length + image_length]
956
+ )
957
+ mask[:, temporal_length:, :] = 0
958
+ mask[:, :, temporal_length:] = 0
959
+ self.register_buffer("mask", mask)
960
+ else:
961
+ self.mask = None
962
+
963
+ if use_new_attention_order:
964
+ # split qkv before split heads
965
+ self.attention_s = QKVAttention(self.num_heads)
966
+ self.attention_t = QKVAttention(self.num_heads)
967
+ else:
968
+ # split heads before split qkv
969
+ self.attention_s = QKVAttentionLegacy(self.num_heads)
970
+ self.attention_t = QKVAttentionLegacy(self.num_heads)
971
+
972
+ if use_relative_position:
973
+ self.relative_position_k = RelativePosition(
974
+ num_units=channels // self.num_heads,
975
+ max_relative_position=temporal_length,
976
+ )
977
+ self.relative_position_v = RelativePosition(
978
+ num_units=channels // self.num_heads,
979
+ max_relative_position=temporal_length,
980
+ )
981
+
982
+ self.proj_out_s = zero_module(
983
+ conv_nd(1, channels, channels, 1)
984
+ ) # conv_dim, in_channels, out_channels, kernel_size
985
+ self.proj_out_t = zero_module(
986
+ conv_nd(1, channels, channels, 1)
987
+ ) # conv_dim, in_channels, out_channels, kernel_size
988
+
989
+ def forward(self, x, mask=None):
990
+ b, c, t, h, w = x.shape
991
+
992
+ # spatial
993
+ out = rearrange(x, "b c t h w -> (b t) c (h w)")
994
+ if self.attn_norm_type == "no_norm":
995
+ qkv = self.qkv_s(out)
996
+ else:
997
+ qkv = self.qkv_s(self.norm_s(out))
998
+ out = self.attention_s(qkv)
999
+ out = self.proj_out_s(out)
1000
+ out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
1001
+ x = x + out # avoid modifying the input tensor in place
1002
+
1003
+ # temporal
1004
+ out = rearrange(x, "b c t h w -> (b h w) c t")
1005
+ if self.attn_norm_type == "no_norm":
1006
+ qkv = self.qkv_t(out)
1007
+ else:
1008
+ qkv = self.qkv_t(self.norm_t(out))
1009
+
1010
+ # relative positional embedding
1011
+ if self.use_relative_position:
1012
+ len_q = qkv.size()[-1]
1013
+ len_k, len_v = len_q, len_q
1014
+ k_rp = self.relative_position_k(len_q, len_k)
1015
+ v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
1016
+ out = self.attention_t(
1017
+ qkv,
1018
+ rp=(k_rp, v_rp),
1019
+ mask=self.mask,
1020
+ use_tempoal_causal_attn=self.use_tempoal_causal_attn,
1021
+ )
1022
+ else:
1023
+ out = self.attention_t(
1024
+ qkv,
1025
+ rp=None,
1026
+ mask=self.mask,
1027
+ use_tempoal_causal_attn=self.use_tempoal_causal_attn,
1028
+ )
1029
+
1030
+ out = self.proj_out_t(out)
1031
+ out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
1032
+
1033
+ return x + out
1034
+
1035
+
1036
+ class QKVAttentionLegacy(nn.Module):
1037
+ """
1038
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
1039
+ """
1040
+
1041
+ def __init__(self, n_heads):
1042
+ super().__init__()
1043
+ self.n_heads = n_heads
1044
+
1045
+ def forward(self, qkv, rp=None, mask=None):
1046
+ """
1047
+ Apply QKV attention.
1048
+
1049
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
1050
+ :return: an [N x (H * C) x T] tensor after attention.
1051
+ """
1052
+ if rp is not None or mask is not None:
1053
+ raise NotImplementedError
1054
+ bs, width, length = qkv.shape
1055
+ assert width % (3 * self.n_heads) == 0
1056
+ ch = width // (3 * self.n_heads)
1057
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
1058
+ scale = 1 / math.sqrt(math.sqrt(ch))
1059
+ weight = torch.einsum(
1060
+ "bct,bcs->bts", q * scale, k * scale
1061
+ ) # More stable with f16 than dividing afterwards
1062
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
1063
+ a = torch.einsum("bts,bcs->bct", weight, v)
1064
+ return a.reshape(bs, -1, length)
1065
+
1066
+ @staticmethod
1067
+ def count_flops(model, _x, y):
1068
+ return count_flops_attn(model, _x, y)
1069
+
1070
+
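For reference, a hedged sketch of the channel packing QKVAttentionLegacy expects, where heads are split before q/k/v; all sizes are illustrative:

    # Illustrative sketch: the "legacy" layout groups channels per head as
    # [q_h0 | k_h0 | v_h0 | q_h1 | k_h1 | v_h1 | ...] along the channel axis.
    import torch

    n_heads, ch, length, bs = 2, 4, 6, 3
    q = torch.randn(bs, n_heads * ch, length)
    k = torch.randn(bs, n_heads * ch, length)
    v = torch.randn(bs, n_heads * ch, length)

    qkv_legacy = torch.cat(
        [torch.cat([q[:, i*ch:(i+1)*ch], k[:, i*ch:(i+1)*ch], v[:, i*ch:(i+1)*ch]], dim=1)
         for i in range(n_heads)], dim=1)
    # the same unpacking QKVAttentionLegacy.forward performs:
    q2, k2, v2 = qkv_legacy.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)
    assert torch.equal(q2.reshape(bs, n_heads * ch, length), q)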
1071
+ class QKVAttention(nn.Module):
1072
+ """
1073
+ A module which performs QKV attention and splits in a different order.
1074
+ """
1075
+
1076
+ def __init__(self, n_heads):
1077
+ super().__init__()
1078
+ self.n_heads = n_heads
1079
+
1080
+ def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
1081
+ """
1082
+ Apply QKV attention.
1083
+
1084
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
1085
+ :return: an [N x (H * C) x T] tensor after attention.
1086
+ """
1087
+ bs, width, length = qkv.shape
1088
+ assert width % (3 * self.n_heads) == 0
1089
+ ch = width // (3 * self.n_heads)
1090
+ # print('qkv', qkv.size())
1091
+ qkv=qkv.contiguous()
1092
+ q, k, v = qkv.chunk(3, dim=1)
1093
+ scale = 1 / math.sqrt(math.sqrt(ch))
1094
+ # print('bs, self.n_heads, ch, length', bs, self.n_heads, ch, length)
1095
+
1096
+ weight = torch.einsum(
1097
+ "bct,bcs->bts",
1098
+ (q * scale).view(bs * self.n_heads, ch, length),
1099
+ (k * scale).view(bs * self.n_heads, ch, length),
1100
+ ) # More stable with f16 than dividing afterwards
1101
+ # weight:[b,t,s] b=bs*n_heads*T
1102
+
1103
+ if rp is not None:
1104
+ k_rp, v_rp = rp # [length, length, head_dim] [8, 8, 48]
1105
+ weight2 = torch.einsum(
1106
+ "bct,tsc->bst", (q * scale).view(bs * self.n_heads, ch, length), k_rp
1107
+ )
1108
+ weight += weight2
1109
+
1110
+ if use_tempoal_causal_attn:
1111
+ # weight = torch.tril(weight)
1112
+ assert mask is None, f"Not implemented for merging two masks!"
1113
+ mask = torch.tril(torch.ones_like(weight)) # build the causal mask on the same device/dtype as the attention weights
1114
+ else:
1115
+ if mask is not None: # only keep upper-left matrix
1116
+ # process mask
1117
+ c, t, _ = weight.shape
1118
+
1119
+ if mask.shape[-1] > t:
1120
+ mask = mask[:, :t, :t]
1121
+ elif mask.shape[-1] < t: # pad the mask with zeros (masked out) up to the current temporal length
1122
+ mask_ = torch.zeros([c, t, t]).to(mask.device)
1123
+ t_ = mask.shape[-1]
1124
+ mask_[:, :t_, :t_] = mask
1125
+ mask = mask_
1126
+ else:
1127
+ assert (
1128
+ weight.shape[-1] == mask.shape[-1]
1129
+ ), f"weight={weight.shape}, mask={mask.shape}"
1130
+
1131
+ if mask is not None:
1132
+ INF = -1e8 # float('-inf')
1133
+ weight = weight.float().masked_fill(mask == 0, INF)
1134
+
1135
+ weight = F.softmax(weight.float(), dim=-1).type(
1136
+ weight.dtype
1137
+ ) # [256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
1138
+ # weight = F.softmax(weight, dim=-1)#[256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
1139
+ a = torch.einsum(
1140
+ "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
1141
+ ) # [256, 48, 8] [b, head_dim, t]
1142
+
1143
+ if rp is not None:
1144
+ a2 = torch.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2) # btc->bct
1145
+ a += a2
1146
+
1147
+ return a.reshape(bs, -1, length)
1148
+
1149
+
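A quick numeric check (sizes are arbitrary) that the ch**-0.25 scaling applied to both q and k above matches the usual 1/sqrt(ch) attention scaling:

    # Illustrative check: scaling q and k each by 1/ch**0.25 before the dot product equals
    # dividing the logits by sqrt(ch) afterwards, but keeps fp16 intermediates smaller.
    import math
    import torch

    ch, length = 32, 5
    q = torch.randn(4, ch, length)
    k = torch.randn(4, ch, length)
    scale = 1 / math.sqrt(math.sqrt(ch))
    w1 = torch.einsum("bct,bcs->bts", q * scale, k * scale)
    w2 = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
    assert torch.allclose(w1, w2, atol=1e-5)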
1150
+ def silu(x):
1151
+ # swish
1152
+ return x * torch.sigmoid(x)
1153
+
1154
+
1155
+ class SiLU(nn.Module):
1156
+ def __init__(self):
1157
+ super(SiLU, self).__init__()
1158
+
1159
+ def forward(self, x):
1160
+ return silu(x)
1161
+
1162
+
1163
+ def Normalize(in_channels, norm_type="group"):
1164
+ assert norm_type in ["group", "batch", "layer"]
1165
+ if norm_type == "group":
1166
+ return torch.nn.GroupNorm(
1167
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
1168
+ )
1169
+ elif norm_type == "batch":
1170
+ return torch.nn.SyncBatchNorm(in_channels)
1171
+ elif norm_type == "layer":
1172
+ return nn.LayerNorm(in_channels)
1173
+
1174
+ class SamePadConv3d(nn.Module):
1175
+ def __init__(
1176
+ self,
1177
+ in_channels,
1178
+ out_channels,
1179
+ kernel_size,
1180
+ stride=1,
1181
+ bias=True,
1182
+ padding_type="replicate",
1183
+ ):
1184
+ super().__init__()
1185
+ if isinstance(kernel_size, int):
1186
+ kernel_size = (kernel_size,) * 3
1187
+ if isinstance(stride, int):
1188
+ stride = (stride,) * 3
1189
+
1190
+ # assumes that the input shape is divisible by stride
1191
+ total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
1192
+ pad_input = []
1193
+ for p in total_pad[::-1]: # reverse since F.pad starts from last dim
1194
+ pad_input.append((p // 2 + p % 2, p // 2))
1195
+ pad_input = sum(pad_input, tuple())
1196
+
1197
+ self.pad_input = pad_input
1198
+ self.padding_type = padding_type
1199
+
1200
+ self.conv = nn.Conv3d(
1201
+ in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias
1202
+ )
1203
+
1204
+ def forward(self, x):
1205
+ tp=x.dtype
1206
+ x = x.float()
1207
+
1208
+ # apply the padding
1209
+ x_padded = F.pad(x, self.pad_input, mode=self.padding_type)
1210
+
1211
+ # cast the result back to the original dtype (e.g. BFloat16) if needed
1212
+ x_padded = x_padded.to(tp)
1213
+
1214
+ return self.conv(x_padded)
1215
+
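A short sketch of how the "same" padding amounts computed above are ordered for F.pad, which pads starting from the last dimension; the kernel and stride values are illustrative:

    # Illustrative sketch: for each of (T, H, W) the layer pads kernel - stride elements,
    # putting the extra element on the front, and passes the tuple in reverse (W, H, T) order.
    import torch
    import torch.nn.functional as F

    kernel_size, stride = (3, 3, 3), (1, 1, 1)
    total_pad = tuple(k - s for k, s in zip(kernel_size, stride))   # (2, 2, 2)
    pad_input = []
    for p in total_pad[::-1]:                                       # last dim first
        pad_input.append((p // 2 + p % 2, p // 2))
    pad_input = sum(pad_input, tuple())                             # (1, 1, 1, 1, 1, 1)

    x = torch.randn(1, 4, 8, 16, 16)
    x_padded = F.pad(x, pad_input, mode="replicate")
    assert x_padded.shape == (1, 4, 10, 18, 18)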
1216
+ class TemporalAttention(nn.Module):
1217
+ def __init__(
1218
+ self,
1219
+ channels,
1220
+ num_heads=1,
1221
+ num_head_channels=-1,
1222
+ max_temporal_length=64,
1223
+ ):
1224
+ """
1225
+ a clean multi-head temporal attention
1226
+ """
1227
+ super().__init__()
1228
+
1229
+ if num_head_channels == -1:
1230
+ self.num_heads = num_heads
1231
+ else:
1232
+ assert (
1233
+ channels % num_head_channels == 0
1234
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
1235
+ self.num_heads = channels // num_head_channels
1236
+
1237
+ self.norm = Normalize(channels)
1238
+ self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
1239
+ self.attention = QKVAttention(self.num_heads)
1240
+ self.relative_position_k = RelativePosition(
1241
+ num_units=channels // self.num_heads,
1242
+ max_relative_position=max_temporal_length,
1243
+ )
1244
+ self.relative_position_v = RelativePosition(
1245
+ num_units=channels // self.num_heads,
1246
+ max_relative_position=max_temporal_length,
1247
+ )
1248
+ self.proj_out = zero_module(
1249
+ conv_nd(1, channels, channels, 1)
1250
+ ) # conv_dim, in_channels, out_channels, kernel_size
1251
+
1252
+ def forward(self, x, mask=None):
1253
+ b, c, t, h, w = x.shape
1254
+ out = rearrange(x, "b c t h w -> (b h w) c t")
1255
+ # torch.Size([4608, 1152, 2])1
1256
+ # torch.Size([4608, 3456, 2])2
1257
+ # torch.Size([4608, 1152, 2])3
1258
+ # torch.Size([4608, 1152, 2])4
1259
+ #print(out.shape,end='1\n')
1260
+ qkv = self.qkv(self.norm(out))
1261
+ #print(qkv.shape,end='2\n')
1262
+
1263
+ len_q = qkv.size()[-1]
1264
+ len_k, len_v = len_q, len_q
1265
+
1266
+ k_rp = self.relative_position_k(len_q, len_k)
1267
+ v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
1268
+ out = self.attention(qkv, rp=(k_rp, v_rp))
1269
+ #print(out.shape,end='3\n')
1270
+ out = self.proj_out(out)
1271
+ #print(out.shape,end='4\n')
1272
+ out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
1273
+
1274
+ return x + out
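A small round-trip sketch of the temporal token layout used above: every spatial position becomes an independent length-T sequence (sizes are illustrative):

    # Illustrative sketch: (B, C, T, H, W) -> B*H*W sequences of T tokens and back.
    import torch
    from einops import rearrange

    b, c, t, h, w = 2, 16, 8, 6, 6
    x = torch.randn(b, c, t, h, w)
    seq = rearrange(x, "b c t h w -> (b h w) c t")
    assert seq.shape == (b * h * w, c, t)
    x_back = rearrange(seq, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
    assert torch.equal(x_back, x)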
1275
+ class TemporalAttention_lin(nn.Module):
1276
+ def __init__(
1277
+ self,
1278
+ channels,
1279
+ num_heads=8,
1280
+ num_head_channels=-1,
1281
+ max_temporal_length=64,
1282
+ ):
1283
+ """
1284
+ a clean multi-head temporal attention
1285
+ """
1286
+ super().__init__()
1287
+
1288
+ if num_head_channels == -1:
1289
+ self.num_heads = num_heads
1290
+ else:
1291
+ assert (
1292
+ channels % num_head_channels == 0
1293
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
1294
+ self.num_heads = channels // num_head_channels
1295
+
1296
+ self.norm = nn.LayerNorm(channels)
1297
+ #self.norm = Normalize(channels)
1298
+ #self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
1299
+ self.qkv = nn.Linear(channels, channels * 3)
1300
+ self.attention = QKVAttention(self.num_heads)
1301
+ self.relative_position_k = RelativePosition(
1302
+ num_units=channels // self.num_heads,
1303
+ max_relative_position=max_temporal_length,
1304
+ )
1305
+ self.relative_position_v = RelativePosition(
1306
+ num_units=channels // self.num_heads,
1307
+ max_relative_position=max_temporal_length,
1308
+ )
1309
+ self.proj_out = nn.Linear(channels, channels)
1310
+
1311
+ def forward(self, x, mask=None):
1312
+ b, c, t, h, w = x.shape
1313
+ out = rearrange(x, "b c t h w -> (b h w) t c")
1314
+ # torch.Size([4608, 1152, 2])1
1315
+ # torch.Size([4608, 3456, 2])2
1316
+ # torch.Size([4608, 1152, 2])3
1317
+ # torch.Size([4608, 1152, 2])4
1318
+ #print(out.shape,end='1\n')
1319
+ qkv = self.qkv(self.norm(out)).transpose(-1, -2)
1320
+ #print(qkv.shape,end='2\n')
1321
+
1322
+ len_q = qkv.size()[-1]
1323
+ len_k, len_v = len_q, len_q
1324
+
1325
+ k_rp = self.relative_position_k(len_q, len_k)
1326
+ v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
1327
+
1328
+ out = self.attention(qkv, rp=(k_rp, v_rp))
1329
+
1330
+ out = self.proj_out(out.transpose(-1, -2)).transpose(-1, -2)
1331
+
1332
+ #print(out.shape,end='4\n')
1333
+ out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
1334
+
1335
+ return x + out
1336
+
1337
+ class AttnBlock3D(nn.Module):
1338
+ def __init__(self, in_channels):
1339
+ super().__init__()
1340
+ self.in_channels = in_channels
1341
+
1342
+ self.norm = Normalize(in_channels)
1343
+ self.q = torch.nn.Conv3d(
1344
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
1345
+ )
1346
+ self.k = torch.nn.Conv3d(
1347
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
1348
+ )
1349
+ self.v = torch.nn.Conv3d(
1350
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
1351
+ )
1352
+ self.proj_out = torch.nn.Conv3d(
1353
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
1354
+ )
1355
+
1356
+ def forward(self, x):
1357
+ h_ = x
1358
+ # self.norm.to(x.device)
1359
+ # self.norm.to(x.dtype)
1360
+ h_ = self.norm(h_)
1361
+ q = self.q(h_)
1362
+ k = self.k(h_)
1363
+ v = self.v(h_)
1364
+
1365
+ b, c, t, h, w = q.shape
1366
+ # q = q.reshape(b,c,h*w) # bcl
1367
+ # q = q.permute(0,2,1) # bcl -> blc l=hw
1368
+ # k = k.reshape(b,c,h*w) # bcl
1369
+ q = rearrange(q, "b c t h w -> (b t) (h w) c") # blc
1370
+ k = rearrange(k, "b c t h w -> (b t) c (h w)") # bcl
1371
+
1372
+ w_ = torch.bmm(q, k) # b,l,l
1373
+ w_ = w_ * (int(c) ** (-0.5))
1374
+ w_ = torch.nn.functional.softmax(w_, dim=2)
1375
+
1376
+ # v = v.reshape(b,c,h*w)
1377
+ v = rearrange(v, "b c t h w -> (b t) c (h w)") # bcl
1378
+
1379
+ # attend to values
1380
+ w_ = w_.permute(0, 2, 1) # bll
1381
+ h_ = torch.bmm(v, w_) # bcl
1382
+
1383
+ # h_ = h_.reshape(b,c,h,w)
1384
+ h_ = rearrange(h_, "(b t) c (h w) -> b c t h w", b=b, h=h)
1385
+
1386
+ h_ = self.proj_out(h_)
1387
+
1388
+ return x + h_
1389
+
1390
+ class MultiHeadAttention3D(nn.Module):
1391
+ def __init__(self, in_channels, num_heads=8):
1392
+ super().__init__()
1393
+ self.in_channels = in_channels
1394
+ self.num_heads = num_heads
1395
+ self.head_dim = in_channels // num_heads
1396
+
1397
+ assert self.head_dim * num_heads == in_channels, "in_channels must be divisible by num_heads"
1398
+
1399
+ self.norm = nn.LayerNorm(in_channels)
1400
+ self.q_linear = nn.Linear(in_channels, in_channels)
1401
+ self.k_linear = nn.Linear(in_channels, in_channels)
1402
+ self.v_linear = nn.Linear(in_channels, in_channels)
1403
+ self.proj_out = nn.Linear(in_channels, in_channels)
1404
+
1405
+ def forward(self, x):
1406
+ b, c, t, h, w = x.shape
1407
+ #print(x.shape)
1408
+ # Normalize and reshape input
1409
+ h_ = rearrange(x, "b c t h w -> (b t) (h w) c")
1410
+ h_ = self.norm(h_)
1411
+
1412
+ # Linear projections
1413
+ q = self.q_linear(h_)
1414
+ k = self.k_linear(h_)
1415
+ v = self.v_linear(h_)
1416
+
1417
+ # Reshape to multi-head
1418
+ q = rearrange(q, "b l (h d) -> b h l d", h=self.num_heads)
1419
+ k = rearrange(k, "b l (h d) -> b h l d", h=self.num_heads)
1420
+ v = rearrange(v, "b l (h d) -> b h l d", h=self.num_heads)
1421
+
1422
+ # Scaled Dot-Product Attention
1423
+ scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
1424
+ attn = F.softmax(scores, dim=-1)
1425
+
1426
+ # Apply attention to values
1427
+ out = torch.matmul(attn, v)
1428
+ out = rearrange(out, "b h l d -> b l (h d)")
1429
+
1430
+ # Project back to original dimension
1431
+ out = self.proj_out(out)
1432
+
1433
+ # Reshape back to original shape
1434
+ out = rearrange(out, "(b t) (h w) c -> b c t h w", b=b, h=h, t=t)
1435
+ #print(out.shape)
1436
+ return x + out
1437
+
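A shape-only sketch (sizes are illustrative) showing that MultiHeadAttention3D attends within each frame only; temporal mixing is left to the temporal attention blocks above:

    # Illustrative sketch: the 5D tensor is split into one independent sequence of H*W
    # spatial tokens per frame, so frames never exchange information in this block.
    import torch
    from einops import rearrange

    b, c, t, h, w = 2, 32, 4, 6, 6
    x = torch.randn(b, c, t, h, w)
    tokens = rearrange(x, "b c t h w -> (b t) (h w) c")
    assert tokens.shape == (b * t, h * w, c)    # one sequence per frame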
1438
 
1439
  class SiglipAE(nn.Module):
1440
  def __init__(self):
 
1466
 
1467
  x=self.encoder(x)
1468
  return x
 
 
 
 
 
 
 
1469
 
 
sae_utils.py DELETED
@@ -1,302 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- from transformers.activations import ACT2FN
5
- from .attention_temporal_videoae import *
6
- from einops import rearrange, reduce, repeat
7
-
8
- try:
9
- import xformers
10
- import xformers.ops as xops
11
-
12
- XFORMERS_IS_AVAILBLE = True
13
- except:
14
- XFORMERS_IS_AVAILBLE = False
15
-
16
- def silu(x):
17
- # swish
18
- return x * torch.sigmoid(x)
19
-
20
-
21
- class SiLU(nn.Module):
22
- def __init__(self):
23
- super(SiLU, self).__init__()
24
-
25
- def forward(self, x):
26
- return silu(x)
27
-
28
-
29
- def Normalize(in_channels, norm_type="group"):
30
- assert norm_type in ["group", "batch",'layer']
31
- if norm_type == "group":
32
- return torch.nn.GroupNorm(
33
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
34
- )
35
- elif norm_type == "batch":
36
- return torch.nn.SyncBatchNorm(in_channels)
37
- elif norm_type == "layer":
38
- return nn.LayerNorm(in_channels)
39
-
40
- class SamePadConv3d(nn.Module):
41
- def __init__(
42
- self,
43
- in_channels,
44
- out_channels,
45
- kernel_size,
46
- stride=1,
47
- bias=True,
48
- padding_type="replicate",
49
- ):
50
- super().__init__()
51
- if isinstance(kernel_size, int):
52
- kernel_size = (kernel_size,) * 3
53
- if isinstance(stride, int):
54
- stride = (stride,) * 3
55
-
56
- # assumes that the input shape is divisible by stride
57
- total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
58
- pad_input = []
59
- for p in total_pad[::-1]: # reverse since F.pad starts from last dim
60
- pad_input.append((p // 2 + p % 2, p // 2))
61
- pad_input = sum(pad_input, tuple())
62
-
63
- self.pad_input = pad_input
64
- self.padding_type = padding_type
65
-
66
- self.conv = nn.Conv3d(
67
- in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias
68
- )
69
-
70
- def forward(self, x):
71
- tp=x.dtype
72
- x = x.float()
73
-
74
- # apply the padding
75
- x_padded = F.pad(x, self.pad_input, mode=self.padding_type)
76
-
77
- # cast the result back to the original dtype (e.g. BFloat16) if needed
78
- x_padded = x_padded.to(tp)
79
-
80
- return self.conv(x_padded)
81
-
82
- class TemporalAttention(nn.Module):
83
- def __init__(
84
- self,
85
- channels,
86
- num_heads=1,
87
- num_head_channels=-1,
88
- max_temporal_length=64,
89
- ):
90
- """
91
- a clean multi-head temporal attention
92
- """
93
- super().__init__()
94
-
95
- if num_head_channels == -1:
96
- self.num_heads = num_heads
97
- else:
98
- assert (
99
- channels % num_head_channels == 0
100
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
101
- self.num_heads = channels // num_head_channels
102
-
103
- self.norm = Normalize(channels)
104
- self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
105
- self.attention = QKVAttention(self.num_heads)
106
- self.relative_position_k = RelativePosition(
107
- num_units=channels // self.num_heads,
108
- max_relative_position=max_temporal_length,
109
- )
110
- self.relative_position_v = RelativePosition(
111
- num_units=channels // self.num_heads,
112
- max_relative_position=max_temporal_length,
113
- )
114
- self.proj_out = zero_module(
115
- conv_nd(1, channels, channels, 1)
116
- ) # conv_dim, in_channels, out_channels, kernel_size
117
-
118
- def forward(self, x, mask=None):
119
- b, c, t, h, w = x.shape
120
- out = rearrange(x, "b c t h w -> (b h w) c t")
121
- # torch.Size([4608, 1152, 2])1
122
- # torch.Size([4608, 3456, 2])2
123
- # torch.Size([4608, 1152, 2])3
124
- # torch.Size([4608, 1152, 2])4
125
- #print(out.shape,end='1\n')
126
- qkv = self.qkv(self.norm(out))
127
- #print(qkv.shape,end='2\n')
128
-
129
- len_q = qkv.size()[-1]
130
- len_k, len_v = len_q, len_q
131
-
132
- k_rp = self.relative_position_k(len_q, len_k)
133
- v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
134
- out = self.attention(qkv, rp=(k_rp, v_rp))
135
- #print(out.shape,end='3\n')
136
- out = self.proj_out(out)
137
- #print(out.shape,end='4\n')
138
- out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
139
-
140
- return x + out
141
- class TemporalAttention_lin(nn.Module):
142
- def __init__(
143
- self,
144
- channels,
145
- num_heads=8,
146
- num_head_channels=-1,
147
- max_temporal_length=64,
148
- ):
149
- """
150
- a clean multi-head temporal attention
151
- """
152
- super().__init__()
153
-
154
- if num_head_channels == -1:
155
- self.num_heads = num_heads
156
- else:
157
- assert (
158
- channels % num_head_channels == 0
159
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
160
- self.num_heads = channels // num_head_channels
161
-
162
- self.norm = nn.LayerNorm(channels)
163
- #self.norm = Normalize(channels)
164
- #self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
165
- self.qkv = nn.Linear(channels, channels * 3)
166
- self.attention = QKVAttention(self.num_heads)
167
- self.relative_position_k = RelativePosition(
168
- num_units=channels // self.num_heads,
169
- max_relative_position=max_temporal_length,
170
- )
171
- self.relative_position_v = RelativePosition(
172
- num_units=channels // self.num_heads,
173
- max_relative_position=max_temporal_length,
174
- )
175
- self.proj_out = nn.Linear(channels, channels)
176
-
177
- def forward(self, x, mask=None):
178
- b, c, t, h, w = x.shape
179
- out = rearrange(x, "b c t h w -> (b h w) t c")
180
- # torch.Size([4608, 1152, 2])1
181
- # torch.Size([4608, 3456, 2])2
182
- # torch.Size([4608, 1152, 2])3
183
- # torch.Size([4608, 1152, 2])4
184
- #print(out.shape,end='1\n')
185
- qkv = self.qkv(self.norm(out)).transpose(-1, -2)
186
- #print(qkv.shape,end='2\n')
187
-
188
- len_q = qkv.size()[-1]
189
- len_k, len_v = len_q, len_q
190
-
191
- k_rp = self.relative_position_k(len_q, len_k)
192
- v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
193
-
194
- out = self.attention(qkv, rp=(k_rp, v_rp))
195
-
196
- out = self.proj_out(out.transpose(-1, -2)).transpose(-1, -2)
197
-
198
- #print(out.shape,end='4\n')
199
- out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
200
-
201
- return x + out
202
-
203
- class AttnBlock3D(nn.Module):
204
- def __init__(self, in_channels):
205
- super().__init__()
206
- self.in_channels = in_channels
207
-
208
- self.norm = Normalize(in_channels)
209
- self.q = torch.nn.Conv3d(
210
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
211
- )
212
- self.k = torch.nn.Conv3d(
213
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
214
- )
215
- self.v = torch.nn.Conv3d(
216
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
217
- )
218
- self.proj_out = torch.nn.Conv3d(
219
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
220
- )
221
-
222
- def forward(self, x):
223
- h_ = x
224
- # self.norm.to(x.device)
225
- # self.norm.to(x.dtype)
226
- h_ = self.norm(h_)
227
- q = self.q(h_)
228
- k = self.k(h_)
229
- v = self.v(h_)
230
-
231
- b, c, t, h, w = q.shape
232
- # q = q.reshape(b,c,h*w) # bcl
233
- # q = q.permute(0,2,1) # bcl -> blc l=hw
234
- # k = k.reshape(b,c,h*w) # bcl
235
- q = rearrange(q, "b c t h w -> (b t) (h w) c") # blc
236
- k = rearrange(k, "b c t h w -> (b t) c (h w)") # bcl
237
-
238
- w_ = torch.bmm(q, k) # b,l,l
239
- w_ = w_ * (int(c) ** (-0.5))
240
- w_ = torch.nn.functional.softmax(w_, dim=2)
241
-
242
- # v = v.reshape(b,c,h*w)
243
- v = rearrange(v, "b c t h w -> (b t) c (h w)") # bcl
244
-
245
- # attend to values
246
- w_ = w_.permute(0, 2, 1) # bll
247
- h_ = torch.bmm(v, w_) # bcl
248
-
249
- # h_ = h_.reshape(b,c,h,w)
250
- h_ = rearrange(h_, "(b t) c (h w) -> b c t h w", b=b, h=h)
251
-
252
- h_ = self.proj_out(h_)
253
-
254
- return x + h_
255
-
256
- class MultiHeadAttention3D(nn.Module):
257
- def __init__(self, in_channels, num_heads=8):
258
- super().__init__()
259
- self.in_channels = in_channels
260
- self.num_heads = num_heads
261
- self.head_dim = in_channels // num_heads
262
-
263
- assert self.head_dim * num_heads == in_channels, "in_channels must be divisible by num_heads"
264
-
265
- self.norm = nn.LayerNorm(in_channels)
266
- self.q_linear = nn.Linear(in_channels, in_channels)
267
- self.k_linear = nn.Linear(in_channels, in_channels)
268
- self.v_linear = nn.Linear(in_channels, in_channels)
269
- self.proj_out = nn.Linear(in_channels, in_channels)
270
-
271
- def forward(self, x):
272
- b, c, t, h, w = x.shape
273
- #print(x.shape)
274
- # Normalize and reshape input
275
- h_ = rearrange(x, "b c t h w -> (b t) (h w) c")
276
- h_ = self.norm(h_)
277
-
278
- # Linear projections
279
- q = self.q_linear(h_)
280
- k = self.k_linear(h_)
281
- v = self.v_linear(h_)
282
-
283
- # Reshape to multi-head
284
- q = rearrange(q, "b l (h d) -> b h l d", h=self.num_heads)
285
- k = rearrange(k, "b l (h d) -> b h l d", h=self.num_heads)
286
- v = rearrange(v, "b l (h d) -> b h l d", h=self.num_heads)
287
-
288
- # Scaled Dot-Product Attention
289
- scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
290
- attn = F.softmax(scores, dim=-1)
291
-
292
- # Apply attention to values
293
- out = torch.matmul(attn, v)
294
- out = rearrange(out, "b h l d -> b l (h d)")
295
-
296
- # Project back to original dimension
297
- out = self.proj_out(out)
298
-
299
- # Reshape back to original shape
300
- out = rearrange(out, "(b t) (h w) c -> b c t h w", b=b, h=h, t=t)
301
- #print(out.shape)
302
- return x + out
 
siglip_encoder.py DELETED
@@ -1,154 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
- from typing import Optional, Tuple, Union, Dict
5
- from PIL import Image
6
- from functools import partial, reduce
7
- from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
8
-
9
- from .base_encoder import BaseVisionTower
10
- import torch.distributed as dist
11
- # --data_path /share/shuyan/video_traindata/anno/\{cinepine_order\}.json \
12
- # --image_folder /share/shuyan/video_traindata/Bunny-v1_0-data/finetune/images \
13
- # --video_folder /share/shuyan/video_traindata \
14
- def rank0_print(*args):
15
- if dist.is_initialized():
16
- if dist.get_rank() == 0:
17
- print(f"Rank {dist.get_rank()}: ", *args)
18
- else:
19
- print(*args)
20
-
21
-
22
- from transformers.image_processing_utils import BatchFeature, get_size_dict
23
- from transformers.image_transforms import (
24
- convert_to_rgb,
25
- normalize,
26
- rescale,
27
- resize,
28
- to_channel_dimension_format,
29
- )
30
- from transformers.image_utils import (
31
- ChannelDimension,
32
- PILImageResampling,
33
- to_numpy_array,
34
- )
35
- class SigLipImageProcessor:
36
- def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
37
- crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
38
- crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
39
-
40
- self.image_mean = image_mean
41
- self.image_std = image_std
42
- self.size = size
43
- self.resample = resample
44
- self.rescale_factor = rescale_factor
45
- self.data_format = data_format
46
- self.crop_size = crop_size
47
-
48
- def preprocess(self, images, return_tensors):
49
- if isinstance(images, Image.Image):
50
- images = [images]
51
- else:
52
- # to adapt video data
53
- images = [to_numpy_array(image) for image in images]
54
- assert isinstance(images, list)
55
-
56
- transforms = [
57
- convert_to_rgb,
58
- to_numpy_array,
59
- partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
60
- partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
61
- partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
62
- partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
63
- ]
64
-
65
- images = reduce(lambda x, f: [*map(f, x)], transforms, images)
66
-
67
- data = {"pixel_values": images}
68
-
69
- return BatchFeature(data=data, tensor_type=return_tensors)
70
-
71
- class SigLipVisionTower(BaseVisionTower):
72
- def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
73
- super(SigLipVisionTower, self).__init__(vision_tower_name, vision_tower_cfg, delay_load)
74
-
75
- # model_path = "google/siglip-so400m-patch14-384"
76
- # base_model_name, res, interp = model_path, 384, 576
77
- # self.vision_tower_name = base_model_name
78
- self.vision_tower_name, res, interp = vision_tower_name, 384, 576
79
- self._image_size = res if res is not None else 512
80
- self.unfreeze_mm_vision_tower = getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False)
81
-
82
- if not delay_load:
83
- rank0_print(f"Loading vision tower: {vision_tower_name}")
84
- self.load_model()
85
- elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
86
- # TODO: better detector is needed.
87
- rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
88
- self.load_model()
89
- elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
90
- rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
91
- self.load_model()
92
- else:
93
- self.cfg_only = self.config
94
-
95
- def load_model(self, device_map=None):
96
- self.vision_model = "siglip"
97
- # clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
98
- print(self.vision_tower_name)
99
- self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
100
-
101
- # self.vision_tower = clip_model.visual.trunk
102
- self.vision_tower.output_tokens = True
103
-
104
- self._hidden_size = self.vision_tower.config.hidden_size
105
-
106
- self.image_processor = SigLipImageProcessor()
107
-
108
- del self.vision_tower.vision_model.encoder.layers[-1:]
109
- self.vision_tower.vision_model.head = nn.Identity()
110
-
111
- self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
112
-
113
- self.is_loaded = True
114
-
115
- def _forward(self, images):
116
- with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
117
- image_features = self.vision_tower.forward(
118
- images.to(device=self.device, dtype=self.dtype),
119
- output_hidden_states=True,
120
- ).hidden_states[-1]
121
- return image_features
122
- @property
123
- def dummy_feature(self):
124
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
125
-
126
- @property
127
- def dtype(self):
128
- for p in self.vision_tower.parameters():
129
- return p.dtype
130
-
131
- @property
132
- def device(self):
133
- for p in self.vision_tower.parameters():
134
- return p.device
135
-
136
- @property
137
- def hidden_size(self):
138
- return self.config.hidden_size
139
-
140
- @property
141
- def num_patches(self):
142
- return (336 // 14) ** 2
143
-
144
- @property
145
- def num_patches_per_side(self):
146
- #return self.config.image_size // self.config.patch_size
147
- return 336//14
148
- #return 27
149
- # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
150
-
151
- @property
152
- def image_size(self):
153
- return 384
154
- #return self.config.image_size
 
utils_encoder.py DELETED
@@ -1,296 +0,0 @@
1
- import importlib
2
- import numpy as np
3
- import cv2, os
4
- import torch
5
- import torch.distributed as dist
6
-
7
-
8
- def count_params(model, verbose=False):
9
- total_params = sum(p.numel() for p in model.parameters())
10
- if verbose:
11
- print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
12
- return total_params
13
-
14
-
15
- def check_istarget(name, para_list):
16
- """
17
- name: full name of source para
18
- para_list: partial name of target para
19
- """
20
- istarget = False
21
- for para in para_list:
22
- if para in name:
23
- return True
24
- return istarget
25
-
26
-
27
- def instantiate_from_config(config):
28
- if not "target" in config:
29
- if config == "__is_first_stage__":
30
- return None
31
- elif config == "__is_unconditional__":
32
- return None
33
- raise KeyError("Expected key `target` to instantiate.")
34
-
35
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
36
-
37
-
38
- def get_obj_from_str(string, reload=False):
39
- module, cls = string.rsplit(".", 1)
40
- if reload:
41
- module_imp = importlib.import_module(module)
42
- importlib.reload(module_imp)
43
- return getattr(importlib.import_module(module, package=None), cls)
44
-
45
-
46
- def load_npz_from_dir(data_dir):
47
- data = [
48
- np.load(os.path.join(data_dir, data_name))["arr_0"]
49
- for data_name in os.listdir(data_dir)
50
- ]
51
- data = np.concatenate(data, axis=0)
52
- return data
53
-
54
-
55
- def load_npz_from_paths(data_paths):
56
- data = [np.load(data_path)["arr_0"] for data_path in data_paths]
57
- data = np.concatenate(data, axis=0)
58
- return data
59
-
60
-
61
- def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
62
- h, w = image.shape[:2]
63
- if resize_short_edge is not None:
64
- k = resize_short_edge / min(h, w)
65
- else:
66
- k = max_resolution / (h * w)
67
- k = k**0.5
68
- h = int(np.round(h * k / 64)) * 64
69
- w = int(np.round(w * k / 64)) * 64
70
- image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
71
- return image
72
-
73
-
74
- def setup_dist(args):
75
- if dist.is_initialized():
76
- return
77
- torch.cuda.set_device(args.local_rank)
78
- torch.distributed.init_process_group("nccl", init_method="env://")
79
-
80
-
81
- # adopted from
82
- # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
83
- # and
84
- # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
85
- # and
86
- # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
87
- #
88
- # thanks!
89
-
90
- import torch.nn as nn
91
- import math
92
- from inspect import isfunction
93
- import torch
94
- from torch import nn
95
- import torch.distributed as dist
96
-
97
-
98
- def gather_data(data, return_np=True):
99
- """gather data from multiple processes to one list"""
100
- data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
101
- dist.all_gather(data_list, data) # gather not supported with NCCL
102
- if return_np:
103
- data_list = [data.cpu().numpy() for data in data_list]
104
- return data_list
105
-
106
-
107
- def autocast(f):
108
- def do_autocast(*args, **kwargs):
109
- with torch.cuda.amp.autocast(
110
- enabled=True,
111
- dtype=torch.get_autocast_gpu_dtype(),
112
- cache_enabled=torch.is_autocast_cache_enabled(),
113
- ):
114
- return f(*args, **kwargs)
115
-
116
- return do_autocast
117
-
118
-
119
- def extract_into_tensor(a, t, x_shape):
120
- b, *_ = t.shape
121
- out = a.gather(-1, t)
122
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
123
-
124
-
125
- def noise_like(shape, device, repeat=False):
126
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
127
- shape[0], *((1,) * (len(shape) - 1))
128
- )
129
- noise = lambda: torch.randn(shape, device=device)
130
- return repeat_noise() if repeat else noise()
131
-
132
-
133
- def default(val, d):
134
- if exists(val):
135
- return val
136
- return d() if isfunction(d) else d
137
-
138
-
139
- def exists(val):
140
- return val is not None
141
-
142
-
143
- def identity(*args, **kwargs):
144
- return nn.Identity()
145
-
146
-
147
- def uniq(arr):
148
- return {el: True for el in arr}.keys()
149
-
150
-
151
- def mean_flat(tensor):
152
- """
153
- Take the mean over all non-batch dimensions.
154
- """
155
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
156
-
157
-
158
- def ismap(x):
159
- if not isinstance(x, torch.Tensor):
160
- return False
161
- return (len(x.shape) == 4) and (x.shape[1] > 3)
162
-
163
-
164
- def isimage(x):
165
- if not isinstance(x, torch.Tensor):
166
- return False
167
- return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
168
-
169
-
170
- def max_neg_value(t):
171
- return -torch.finfo(t.dtype).max
172
-
173
-
174
- def shape_to_str(x):
175
- shape_str = "x".join([str(x) for x in x.shape])
176
- return shape_str
177
-
178
-
179
- def init_(tensor):
180
- dim = tensor.shape[-1]
181
- std = 1 / math.sqrt(dim)
182
- tensor.uniform_(-std, std)
183
- return tensor
184
-
185
-
186
- # ckpt = torch.utils.checkpoint.checkpoint
187
-
188
-
189
- # def checkpoint(func, inputs, params, flag):
190
- # """
191
- # Evaluate a function without caching intermediate activations, allowing for
192
- # reduced memory at the expense of extra compute in the backward pass.
193
- # :param func: the function to evaluate.
194
- # :param inputs: the argument sequence to pass to `func`.
195
- # :param params: a sequence of parameters `func` depends on but does not
196
- # explicitly take as arguments.
197
- # :param flag: if False, disable gradient checkpointing.
198
- # """
199
- # if flag:
200
- # return ckpt(func, *inputs)
201
- # else:
202
- # return func(*inputs)
203
-
204
-
205
- def disabled_train(self, mode=True):
206
- """Overwrite model.train with this function to make sure train/eval mode
207
- does not change anymore."""
208
- return self
209
-
210
-
211
- def zero_module(module):
212
- """
213
- Zero out the parameters of a module and return it.
214
- """
215
- for p in module.parameters():
216
- p.detach().zero_()
217
- return module
218
-
219
-
220
- def scale_module(module, scale):
221
- """
222
- Scale the parameters of a module and return it.
223
- """
224
- for p in module.parameters():
225
- p.detach().mul_(scale)
226
- return module
227
-
228
-
229
- def conv_nd(dims, *args, **kwargs):
230
- """
231
- Create a 1D, 2D, or 3D convolution module.
232
- """
233
- if dims == 1:
234
- return nn.Conv1d(*args, **kwargs)
235
- elif dims == 2:
236
- return nn.Conv2d(*args, **kwargs)
237
- elif dims == 3:
238
- return nn.Conv3d(*args, **kwargs)
239
- raise ValueError(f"unsupported dimensions: {dims}")
240
-
241
-
242
- def linear(*args, **kwargs):
243
- """
244
- Create a linear module.
245
- """
246
- return nn.Linear(*args, **kwargs)
247
-
248
-
249
- def avg_pool_nd(dims, *args, **kwargs):
250
- """
251
- Create a 1D, 2D, or 3D average pooling module.
252
- """
253
- if dims == 1:
254
- return nn.AvgPool1d(*args, **kwargs)
255
- elif dims == 2:
256
- return nn.AvgPool2d(*args, **kwargs)
257
- elif dims == 3:
258
- return nn.AvgPool3d(*args, **kwargs)
259
- raise ValueError(f"unsupported dimensions: {dims}")
260
-
261
-
262
- def nonlinearity(type="silu"):
263
- if type == "silu":
264
- return nn.SiLU()
265
- elif type == "leaky_relu":
266
- return nn.LeakyReLU()
267
-
268
-
269
- class GroupNormSpecific(nn.GroupNorm):
270
- def forward(self, x):
271
- if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
272
- return super().forward(x).type(x.dtype)
273
- else:
274
- return super().forward(x.float()).type(x.dtype)
275
-
276
-
277
- def normalization(channels, num_groups=32):
278
- """
279
- Make a standard normalization layer.
280
- :param channels: number of input channels.
281
- :return: an nn.Module for normalization.
282
- """
283
- return GroupNormSpecific(num_groups, channels)
284
-
285
-
286
- class HybridConditioner(nn.Module):
287
-
288
- def __init__(self, c_concat_config, c_crossattn_config):
289
- super().__init__()
290
- self.concat_conditioner = instantiate_from_config(c_concat_config)
291
- self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
292
-
293
- def forward(self, c_concat, c_crossattn):
294
- c_concat = self.concat_conditioner(c_concat)
295
- c_crossattn = self.crossattn_conditioner(c_crossattn)
296
- return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
 
multimodal_projector/builder.py → vision_projector_builder.py RENAMED
@@ -1,8 +1,36 @@
1
  import torch
2
  import torch.nn as nn
3
  import re
 
4
 
5
- from .pooler_projector import PoolerProjector
6
 
7
 
8
  class IdentityMap(nn.Module):
 
1
  import torch
2
  import torch.nn as nn
3
  import re
4
+ import math
5
+ from transformers.models.clip.modeling_clip import CLIPVisionModel
6
+
7
+
8
+ class PoolerProjector(nn.Module):
9
+ def __init__(self, config, vision_cfg):
10
+ super().__init__()
11
+ self._config = config
12
+ self.hw = vision_cfg.image_size // vision_cfg.patch_size
13
+
14
+ self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
15
+
16
+ self.proj = nn.Sequential(
17
+ nn.GELU(),
18
+ nn.Linear(config.hidden_size, config.hidden_size),
19
+ )
20
+
21
+ def forward(self, x, *args, **kwargs):
22
+ height = width = self.hw
23
+ assert height * width == x.shape[1]
24
+ x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
25
+ x = self.conv_pool(x)
26
+ x = x.flatten(2).transpose(1, 2)
27
+ x = self.proj(x)
28
+ return x
29
+
30
+ @property
31
+ def config(self):
32
+ return {"mm_projector_type": "pooler"}
33
 
 
34
 
35
 
36
  class IdentityMap(nn.Module):
multimodal_resampler/spatial_pool.py → vision_resampler_builder.py RENAMED
@@ -43,3 +43,26 @@ class SpatialPool(nn.Module):
43
  @property
44
  def hidden_size(self):
45
  return self.out_channels
 
43
  @property
44
  def hidden_size(self):
45
  return self.out_channels
46
+
47
+
48
+
49
+ class IdentityMap(torch.nn.Module):
50
+ def __init__(self):
51
+ super().__init__()
52
+
53
+ def forward(self, x, *args, **kwargs):
54
+ return x
55
+
56
+ @property
57
+ def config(self):
58
+ return {"mm_resampler_type": None}
59
+
60
+
61
+ def build_vision_resampler(model_args, delay_load=False, **kwargs):
62
+ resampler_type = getattr(model_args, "mm_resampler_type", None)
63
+ if resampler_type == "spatial_pool":
64
+ return SpatialPool(model_args, **kwargs)
65
+ elif resampler_type is None:
66
+ return IdentityMap()
67
+
68
+ raise ValueError(f"Unknown resampler type: {resampler_type}")
multimodal_encoder/siglip_encoder.py → vision_tower_builder.py RENAMED
@@ -1,24 +1,13 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
  from typing import Optional, Tuple, Union, Dict
5
  from PIL import Image
6
  from functools import partial, reduce
7
  from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
8
-
9
- from .base_encoder import BaseVisionTower
10
  import torch.distributed as dist
11
- # --data_path /share/shuyan/video_traindata/anno/\{cinepine_order\}.json \
12
- # --image_folder /share/shuyan/video_traindata/Bunny-v1_0-data/finetune/images \
13
- # --video_folder /share/shuyan/video_traindata \
14
- def rank0_print(*args):
15
- if dist.is_initialized():
16
- if dist.get_rank() == 0:
17
- print(f"Rank {dist.get_rank()}: ", *args)
18
- else:
19
- print(*args)
20
-
21
-
22
  from transformers.image_processing_utils import BatchFeature, get_size_dict
23
  from transformers.image_transforms import (
24
  convert_to_rgb,
@@ -32,6 +21,78 @@ from transformers.image_utils import (
32
  PILImageResampling,
33
  to_numpy_array,
34
  )
 
35
  class SigLipImageProcessor:
36
  def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
37
  crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
@@ -151,4 +212,18 @@ class SigLipVisionTower(BaseVisionTower):
151
  @property
152
  def image_size(self):
153
  return 384
154
- #return self.config.image_size
 
1
+ import os
 
 
2
  from typing import Optional, Tuple, Union, Dict
3
  from PIL import Image
4
  from functools import partial, reduce
5
  from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
 
 
6
  import torch.distributed as dist
7
+ from abc import ABC, abstractmethod
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
 
 
 
 
 
 
 
11
  from transformers.image_processing_utils import BatchFeature, get_size_dict
12
  from transformers.image_transforms import (
13
  convert_to_rgb,
 
21
  PILImageResampling,
22
  to_numpy_array,
23
  )
24
+
25
+ def rank0_print(*args):
26
+ if dist.is_initialized():
27
+ if dist.get_rank() == 0:
28
+ print(f"Rank {dist.get_rank()}: ", *args)
29
+ else:
30
+ print(*args)
31
+
32
+
33
+ class BaseVisionTower(nn.Module):
34
+ def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
35
+ super().__init__()
36
+
37
+ self.is_loaded = False
38
+
39
+ self.vision_tower_name = vision_tower_name
40
+ self.delay_load = delay_load
41
+
42
+ @abstractmethod
43
+ def load_model(self, device_map=None):
44
+ raise NotImplementedError("Subclasses must implement load_model")
45
+
46
+ @abstractmethod
47
+ def _forward(self, images):
48
+ raise NotImplementedError("Subclasses must implement forward")
49
+
50
+ def forward(self, images):
51
+ if type(images) is list:
52
+ image_features = [self._forward(image.unsqueeze(0)) for image in images]
53
+ else:
54
+ image_features = self._forward(images)
55
+
56
+ return image_features
57
+
58
+ @property
59
+ def dummy_feature(self):
60
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
61
+
62
+ @property
63
+ def dtype(self):
64
+ # Dynamically infer the dtype from the first parameter, if not explicitly specified
65
+ if hasattr(self.vision_tower, "dtype"):
66
+ return self.vision_tower.dtype
67
+ else:
68
+ params = list(self.vision_tower.parameters())
69
+ return (
70
+ params[0].dtype if len(params) > 0 else torch.float32
71
+ ) # Default to torch.float32 if no parameters
72
+
73
+ @property
74
+ def device(self):
75
+ # Dynamically infer the device from the first parameter, if not explicitly specified
76
+ if hasattr(self.vision_tower, "device"):
77
+ return self.vision_tower.device
78
+ else:
79
+ params = list(self.vision_tower.parameters())
80
+ return (
81
+ params[0].device if len(params) > 0 else torch.device("cpu")
82
+ ) # Default to CPU if no parameters
83
+ @property
84
+ def config(self):
85
+ if self.is_loaded:
86
+ return self.vision_tower.config
87
+ else:
88
+ return self.cfg_only
89
+ @property
90
+ def hidden_size(self):
91
+ try:
92
+ return self.config.hidden_size
93
+ except:
94
+ return self._hidden_size
95
+
96
  class SigLipImageProcessor:
97
  def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
98
  crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
 
212
  @property
213
  def image_size(self):
214
  return 384
215
+
216
+ def build_vision_tower(vision_tower_cfg, **kwargs):
217
+
218
+ vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
219
+ is_absolute_path_exists = os.path.exists(vision_tower)
220
+ use_s2 = getattr(vision_tower_cfg, "s2", False)
221
+
222
+ # NOTE: only the SigLip tower is wired up in this repo, so the builder always
+ # returns a SigLipVisionTower for the configured checkpoint; other tower types
+ # would need to be added here before a fallback error.
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)