# BAAI/Video-XL-2 -- sae.py
import importlib
import math
import os
from inspect import isfunction
from typing import Optional, Any

import cv2
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from transformers.activations import ACT2FN

try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params
def check_istarget(name, para_list):
"""
name: full name of source para
para_list: partial name of target para
"""
istarget = False
for para in para_list:
if para in name:
return True
return istarget
def instantiate_from_config(config):
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def load_npz_from_dir(data_dir):
data = [
np.load(os.path.join(data_dir, data_name))["arr_0"]
for data_name in os.listdir(data_dir)
]
data = np.concatenate(data, axis=0)
return data
def load_npz_from_paths(data_paths):
data = [np.load(data_path)["arr_0"] for data_path in data_paths]
data = np.concatenate(data, axis=0)
return data
def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
h, w = image.shape[:2]
if resize_short_edge is not None:
k = resize_short_edge / min(h, w)
else:
k = max_resolution / (h * w)
k = k**0.5
h = int(np.round(h * k / 64)) * 64
w = int(np.round(w * k / 64)) * 64
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
return image
def setup_dist(args):
if dist.is_initialized():
return
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group("nccl", init_method="env://")
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
def gather_data(data, return_np=True):
"""gather data from multiple processes to one list"""
data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
dist.all_gather(data_list, data) # gather not supported with NCCL
if return_np:
data_list = [data.cpu().numpy() for data in data_list]
return data_list
def autocast(f):
def do_autocast(*args, **kwargs):
with torch.cuda.amp.autocast(
enabled=True,
dtype=torch.get_autocast_gpu_dtype(),
cache_enabled=torch.is_autocast_cache_enabled(),
):
return f(*args, **kwargs)
return do_autocast
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1,) * (len(shape) - 1))
)
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def exists(val):
return val is not None
def identity(*args, **kwargs):
return nn.Identity()
def uniq(arr):
return {el: True for el in arr}.keys()
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def shape_to_str(x):
shape_str = "x".join([str(x) for x in x.shape])
return shape_str
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def nonlinearity(type="silu"):
if type == "silu":
return nn.SiLU()
elif type == "leaky_relu":
return nn.LeakyReLU()
class GroupNormSpecific(nn.GroupNorm):
def forward(self, x):
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
return super().forward(x).type(x.dtype)
else:
return super().forward(x.float()).type(x.dtype)
def normalization(channels, num_groups=32):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNormSpecific(num_groups, channels)
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class RelativePosition(nn.Module):
"""https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py"""
def __init__(self, num_units, max_relative_position):
super().__init__()
self.num_units = num_units
self.max_relative_position = max_relative_position
self.embeddings_table = nn.Parameter(
torch.Tensor(max_relative_position * 2 + 1, num_units)
)
nn.init.xavier_uniform_(self.embeddings_table)
def forward(self, length_q, length_k):
device = self.embeddings_table.device
range_vec_q = torch.arange(length_q, device=device)
range_vec_k = torch.arange(length_k, device=device)
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
distance_mat_clipped = torch.clamp(
distance_mat, -self.max_relative_position, self.max_relative_position
)
final_mat = distance_mat_clipped + self.max_relative_position
# final_mat = torch.LongTensor(final_mat).to(self.embeddings_table.device)
# final_mat = torch.tensor(final_mat, device=self.embeddings_table.device, dtype=torch.long)
final_mat = final_mat.long()
embeddings = self.embeddings_table[final_mat]
return embeddings
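# Illustrative shape check for RelativePosition (not used by the model itself; the
# sizes below are arbitrary). Each (query, key) offset pair gets a num_units-dim embedding.
def _demo_relative_position():
    rel = RelativePosition(num_units=48, max_relative_position=16)
    emb = rel(length_q=8, length_k=8)
    assert emb.shape == (8, 8, 48)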
class TemporalCrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
temporal_length=None, # For relative positional representation and image-video joint training.
image_length=None, # For image-video joint training.
        use_relative_position=False,  # whether to use relative positional representation in temporal attention.
img_video_joint_train=False, # For image-video joint training.
use_tempoal_causal_attn=False,
bidirectional_causal_attn=False,
tempoal_attn_type=None,
joint_train_mode="same_batch",
**kwargs,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.context_dim = context_dim
self.scale = dim_head**-0.5
self.heads = heads
self.temporal_length = temporal_length
self.use_relative_position = use_relative_position
self.img_video_joint_train = img_video_joint_train
self.bidirectional_causal_attn = bidirectional_causal_attn
self.joint_train_mode = joint_train_mode
assert joint_train_mode in ["same_batch", "diff_batch"]
self.tempoal_attn_type = tempoal_attn_type
if bidirectional_causal_attn:
assert use_tempoal_causal_attn
if tempoal_attn_type:
assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"]
assert not use_tempoal_causal_attn
assert not (
img_video_joint_train and (self.joint_train_mode == "same_batch")
)
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
assert not (
img_video_joint_train
and (self.joint_train_mode == "same_batch")
and use_tempoal_causal_attn
)
if img_video_joint_train:
if self.joint_train_mode == "same_batch":
mask = torch.ones(
[1, temporal_length + image_length, temporal_length + image_length]
)
# mask[:, image_length:, :] = 0
# mask[:, :, image_length:] = 0
mask[:, temporal_length:, :] = 0
mask[:, :, temporal_length:] = 0
self.mask = mask
else:
self.mask = None
elif use_tempoal_causal_attn:
# normal causal attn
self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
elif tempoal_attn_type == "sparse_causal":
# all frames interact with only the `prev` & self frame
mask1 = torch.tril(
torch.ones([1, temporal_length, temporal_length])
).bool() # true indicates keeping
mask2 = torch.zeros(
[1, temporal_length, temporal_length]
) # initialize to same shape with mask1
mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
torch.ones([1, temporal_length - 2, temporal_length - 2])
)
mask2 = (1 - mask2).bool() # false indicates masking
self.mask = mask1 & mask2
elif tempoal_attn_type == "sparse_causal_first":
# all frames interact with only the `first` & self frame
mask1 = torch.tril(
torch.ones([1, temporal_length, temporal_length])
).bool() # true indicates keeping
mask2 = torch.zeros([1, temporal_length, temporal_length])
mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril(
torch.ones([1, temporal_length - 2, temporal_length - 2])
)
mask2 = (1 - mask2).bool() # false indicates masking
self.mask = mask1 & mask2
else:
self.mask = None
if use_relative_position:
assert temporal_length is not None
self.relative_position_k = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length
)
self.relative_position_v = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
nn.init.constant_(self.to_q.weight, 0)
nn.init.constant_(self.to_k.weight, 0)
nn.init.constant_(self.to_v.weight, 0)
nn.init.constant_(self.to_out[0].weight, 0)
nn.init.constant_(self.to_out[0].bias, 0)
def forward(self, x, context=None, mask=None):
# if context is None:
# print(f'[Temp Attn] x={x.shape},context=None')
# else:
# print(f'[Temp Attn] x={x.shape},context={context.shape}')
nh = self.heads
out = x
q = self.to_q(out)
# if context is not None:
# print(f'temporal context 1 ={context.shape}')
# print(f'x={x.shape}')
context = default(context, x)
# print(f'temporal context 2 ={context.shape}')
k = self.to_k(context)
v = self.to_v(context)
# print(f'q ={q.shape},k={k.shape}')
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if self.use_relative_position:
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
k2 = self.relative_position_k(len_q, len_k)
sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale # TODO check
sim += sim2
# print('mask',mask)
if exists(self.mask):
if mask is None:
mask = self.mask.to(sim.device)
else:
mask = self.mask.to(sim.device).bool() & mask # .to(sim.device)
else:
mask = mask
# if self.img_video_joint_train:
# # process mask (make mask same shape with sim)
# c, h, w = mask.shape
# c, t, s = sim.shape
# # assert(h == w and t == s),f"mask={mask.shape}, sim={sim.shape}, h={h}, w={w}, t={t}, s={s}"
# if h > t:
# mask = mask[:, :t, :]
# elif h < t: # pad zeros to mask (no attention) only initial mask =1 area compute weights
# mask_ = torch.zeros([c,t,w]).to(mask.device)
# mask_[:, :h, :] = mask
# mask = mask_
# c, h, w = mask.shape
# if w > s:
# mask = mask[:, :, :s]
# elif w < s: # pad zeros to mask
# mask_ = torch.zeros([c,h,s]).to(mask.device)
# mask_[:, :, :w] = mask
# mask = mask_
# max_neg_value = -torch.finfo(sim.dtype).max
# sim = sim.float().masked_fill(mask == 0, max_neg_value)
        if mask is not None:
            max_neg_value = -1e9
            sim = sim + (1 - mask.float()) * max_neg_value  # mask == 1 keeps, mask == 0 masks out
# print('sim after masking: ', sim)
# if torch.isnan(sim).any() or torch.isinf(sim).any() or (not sim.any()):
# print(f'sim [after masking], isnan={torch.isnan(sim).any()}, isinf={torch.isinf(sim).any()}, allzero={not sim.any()}')
attn = sim.softmax(dim=-1)
# print('attn after softmax: ', attn)
# if torch.isnan(attn).any() or torch.isinf(attn).any() or (not attn.any()):
# print(f'attn [after softmax], isnan={torch.isnan(attn).any()}, isinf={torch.isinf(attn).any()}, allzero={not attn.any()}')
# attn = torch.where(torch.isnan(attn), torch.full_like(attn,0), attn)
# if torch.isinf(attn.detach()).any():
# import pdb;pdb.set_trace()
# if torch.isnan(attn.detach()).any():
# import pdb;pdb.set_trace()
out = einsum("b i j, b j d -> b i d", attn, v)
        if self.bidirectional_causal_attn:
            mask_reverse = torch.triu(
                torch.ones(
                    [1, self.temporal_length, self.temporal_length], device=sim.device
                )
            )
            # use the same large negative constant as above (`max_neg_value` would
            # otherwise resolve to the module-level helper function here)
            sim_reverse = sim.float().masked_fill(mask_reverse == 0, -1e9)
            attn_reverse = sim_reverse.softmax(dim=-1)
            out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
            out += out_reverse
if self.use_relative_position:
v2 = self.relative_position_v(len_q, len_v)
            out2 = einsum("b t s, t s d -> b t d", attn, v2)  # TODO check
            out += out2  # TODO check: add the relative-position term on the per-head tensors first, then merge heads
        out = rearrange(out, "(b h) n d -> b n (h d)", h=nh)  # merge heads
return self.to_out(out)
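# Illustrative shape check for TemporalCrossAttention (hypothetical sizes). Note that
# to_q/to_k/to_v/to_out are zero-initialized above, so a freshly built block returns zeros.
def _demo_temporal_cross_attention():
    attn = TemporalCrossAttention(
        query_dim=320, heads=8, dim_head=40,
        temporal_length=16, use_relative_position=True,
    )
    x = torch.randn(2, 16, 320)  # (batch * h * w, temporal_length, query_dim)
    out = attn(x)
    assert out.shape == (2, 16, 320)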
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
sa_shared_kv=False,
shared_type="only_first",
**kwargs,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.sa_shared_kv = sa_shared_kv
assert shared_type in [
"only_first",
"all_frames",
"first_and_prev",
"only_prev",
"full",
"causal",
"full_qkv",
]
self.shared_type = shared_type
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
h = self.heads
b = x.shape[0]
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if self.sa_shared_kv:
if self.shared_type == "only_first":
k, v = map(
lambda xx: rearrange(xx[0].unsqueeze(0), "b n c -> (b n) c")
.unsqueeze(0)
.repeat(b, 1, 1),
(k, v),
)
else:
raise NotImplementedError
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
def efficient_forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
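# Illustrative shape check for CrossAttention with an external context (hypothetical
# sizes resembling a text-conditioned setup; not tied to any particular checkpoint).
def _demo_cross_attention():
    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
    x = torch.randn(2, 64, 320)        # queries: (batch, query tokens, query_dim)
    context = torch.randn(2, 77, 768)  # context: (batch, context tokens, context_dim)
    out = attn(x, context=context)
    assert out.shape == (2, 64, 320)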
class VideoSpatialCrossAttention(CrossAttention):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
super().__init__(query_dim, context_dim, heads, dim_head, dropout)
    def forward(self, x, context=None, mask=None):
        b, c, t, h, w = x.shape
        if context is not None:
            context = context.repeat(t, 1, 1)
        x = spatial_attn_reshape(x)  # (b t) (h w) c
        x = super().forward(x, context=context) + x  # residual add in the reshaped layout
        return spatial_attn_reshape_back(x, b, h)
def spatial_attn_reshape(x):
return rearrange(x, "b c t h w -> (b t) (h w) c")
def spatial_attn_reshape_back(x, b, h):
return rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
def temporal_attn_reshape(x):
return rearrange(x, "b c t h w -> (b h w) t c")
def temporal_attn_reshape_back(x, b, h, w):
return rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)
def local_spatial_temporal_attn_reshape(x, window_size):
B, C, T, H, W = x.shape
NH = H // window_size
NW = W // window_size
# x = x.view(B, C, T, NH, window_size, NW, window_size)
# tokens = x.permute(0, 1, 2, 3, 5, 4, 6).contiguous()
# tokens = tokens.view(-1, window_size, window_size, C)
x = rearrange(
x,
"b c t (nh wh) (nw ww) -> b c t nh wh nw ww",
nh=NH,
nw=NW,
wh=window_size,
ww=window_size,
).contiguous() # # B, C, T, NH, NW, window_size, window_size
x = rearrange(
x, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c"
) # (B, NH, NW) (T, window_size, window_size) C
return x
def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t):
B, L, C = x.shape
NH = h // window_size
NW = w // window_size
x = rearrange(
x,
"(b nh nw) (t wh ww) c -> b c t nh wh nw ww",
b=b,
nh=NH,
nw=NW,
t=t,
wh=window_size,
ww=window_size,
)
x = rearrange(x, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)")
return x
class SpatialTemporalTransformer(nn.Module):
"""
Transformer block for video-like data (5D tensor).
First, project the input (aka embedding) with NO reshape.
Then apply standard transformer action.
The 5D -> 3D reshape operation will be done in the specific attention module.
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
# Temporal stuff
temporal_length=None,
image_length=None,
use_relative_position=True,
img_video_joint_train=False,
cross_attn_on_tempoal=False,
temporal_crossattn_type="selfattn",
order="stst",
temporalcrossfirst=False,
split_stcontext=False,
temporal_context_dim=None,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv3d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
        # NOTE: BasicTransformerBlockST is referenced here but not defined in this file;
        # it is expected to be provided by the surrounding codebase.
        self.transformer_blocks = nn.ModuleList(
            [
BasicTransformerBlockST(
inner_dim,
n_heads,
d_head,
dropout=dropout,
# cross attn
context_dim=context_dim,
# temporal attn
temporal_length=temporal_length,
image_length=image_length,
use_relative_position=use_relative_position,
img_video_joint_train=img_video_joint_train,
temporal_crossattn_type=temporal_crossattn_type,
order=order,
temporalcrossfirst=temporalcrossfirst,
split_stcontext=split_stcontext,
temporal_context_dim=temporal_context_dim,
**kwargs,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
def forward(self, x, context=None, temporal_context=None, **kwargs):
# note: if no context is given, cross-attention defaults to self-attention
assert x.dim() == 5, f"x shape = {x.shape}"
b, c, t, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
for block in self.transformer_blocks:
x = block(x, context=context, temporal_context=temporal_context, **kwargs)
x = self.proj_out(x)
return x + x_in
class STAttentionBlock2(nn.Module):
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False, # not used, only used in ResBlock
use_new_attention_order=False, # QKVAttention or QKVAttentionLegacy
temporal_length=16, # used in relative positional representation.
image_length=8, # used for image-video joint training.
        use_relative_position=False,  # whether to use relative positional representation in temporal attention.
img_video_joint_train=False,
# norm_type="groupnorm",
attn_norm_type="group",
use_tempoal_causal_attn=False,
):
"""
version 1: guided_diffusion implemented version
version 2: remove args input argument
"""
super().__init__()
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.temporal_length = temporal_length
self.image_length = image_length
self.use_relative_position = use_relative_position
self.img_video_joint_train = img_video_joint_train
self.attn_norm_type = attn_norm_type
assert self.attn_norm_type in ["group", "no_norm"]
self.use_tempoal_causal_attn = use_tempoal_causal_attn
if self.attn_norm_type == "group":
self.norm_s = normalization(channels)
self.norm_t = normalization(channels)
self.qkv_s = conv_nd(1, channels, channels * 3, 1)
self.qkv_t = conv_nd(1, channels, channels * 3, 1)
if self.img_video_joint_train:
mask = torch.ones(
[1, temporal_length + image_length, temporal_length + image_length]
)
mask[:, temporal_length:, :] = 0
mask[:, :, temporal_length:] = 0
self.register_buffer("mask", mask)
else:
self.mask = None
if use_new_attention_order:
# split qkv before split heads
self.attention_s = QKVAttention(self.num_heads)
self.attention_t = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention_s = QKVAttentionLegacy(self.num_heads)
self.attention_t = QKVAttentionLegacy(self.num_heads)
if use_relative_position:
self.relative_position_k = RelativePosition(
num_units=channels // self.num_heads,
max_relative_position=temporal_length,
)
self.relative_position_v = RelativePosition(
num_units=channels // self.num_heads,
max_relative_position=temporal_length,
)
self.proj_out_s = zero_module(
conv_nd(1, channels, channels, 1)
) # conv_dim, in_channels, out_channels, kernel_size
self.proj_out_t = zero_module(
conv_nd(1, channels, channels, 1)
) # conv_dim, in_channels, out_channels, kernel_size
def forward(self, x, mask=None):
b, c, t, h, w = x.shape
# spatial
out = rearrange(x, "b c t h w -> (b t) c (h w)")
if self.attn_norm_type == "no_norm":
qkv = self.qkv_s(out)
else:
qkv = self.qkv_s(self.norm_s(out))
out = self.attention_s(qkv)
out = self.proj_out_s(out)
out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
        x = x + out  # avoid in-place addition on the input tensor
# temporal
out = rearrange(x, "b c t h w -> (b h w) c t")
if self.attn_norm_type == "no_norm":
qkv = self.qkv_t(out)
else:
qkv = self.qkv_t(self.norm_t(out))
# relative positional embedding
if self.use_relative_position:
len_q = qkv.size()[-1]
len_k, len_v = len_q, len_q
k_rp = self.relative_position_k(len_q, len_k)
v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
out = self.attention_t(
qkv,
rp=(k_rp, v_rp),
mask=self.mask,
use_tempoal_causal_attn=self.use_tempoal_causal_attn,
)
else:
out = self.attention_t(
qkv,
rp=None,
mask=self.mask,
use_tempoal_causal_attn=self.use_tempoal_causal_attn,
)
out = self.proj_out_t(out)
out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
return x + out
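# `QKVAttentionLegacy.count_flops` below calls `count_flops_attn`, which is not defined
# in this file. The helper below follows the guided-diffusion implementation referenced
# in the attribution comments above; it is only needed when profiling with `thop`.
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an attention operation.
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # Two matmuls with the same cost: one for the attention weights,
    # one for combining the value vectors.
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += torch.DoubleTensor([matmul_ops])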
class QKVAttentionLegacy(nn.Module):
"""
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, rp=None, mask=None):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
if rp is not None or mask is not None:
raise NotImplementedError
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
        qkv = qkv.contiguous()
        q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
# print('bs, self.n_heads, ch, length', bs, self.n_heads, ch, length)
weight = torch.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
# weight:[b,t,s] b=bs*n_heads*T
if rp is not None:
k_rp, v_rp = rp # [length, length, head_dim] [8, 8, 48]
weight2 = torch.einsum(
"bct,tsc->bst", (q * scale).view(bs * self.n_heads, ch, length), k_rp
)
weight += weight2
        if use_tempoal_causal_attn:
            # weight = torch.tril(weight)
            assert mask is None, "Not implemented for merging two masks!"
            mask = torch.tril(torch.ones(weight.shape, device=weight.device))
else:
if mask is not None: # only keep upper-left matrix
# process mask
c, t, _ = weight.shape
if mask.shape[-1] > t:
mask = mask[:, :t, :t]
elif mask.shape[-1] < t: # pad ones
mask_ = torch.zeros([c, t, t]).to(mask.device)
t_ = mask.shape[-1]
mask_[:, :t_, :t_] = mask
mask = mask_
else:
assert (
weight.shape[-1] == mask.shape[-1]
), f"weight={weight.shape}, mask={mask.shape}"
if mask is not None:
INF = -1e8 # float('-inf')
weight = weight.float().masked_fill(mask == 0, INF)
weight = F.softmax(weight.float(), dim=-1).type(
weight.dtype
) # [256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
# weight = F.softmax(weight, dim=-1)#[256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
a = torch.einsum(
"bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
) # [256, 48, 8] [b, head_dim, t]
if rp is not None:
a2 = torch.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2) # btc->bct
a += a2
return a.reshape(bs, -1, length)
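# Illustrative shape check for QKVAttention (arbitrary sizes): the input packs Q, K and V
# along the channel dimension and the output has the heads merged back together.
def _demo_qkv_attention():
    n_heads, ch, t = 4, 16, 8
    attn = QKVAttention(n_heads)
    qkv = torch.randn(2, 3 * n_heads * ch, t)  # [N, 3 * H * C, T]
    out = attn(qkv)
    assert out.shape == (2, n_heads * ch, t)   # [N, H * C, T]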
def silu(x):
# swish
return x * torch.sigmoid(x)
class SiLU(nn.Module):
def __init__(self):
super(SiLU, self).__init__()
def forward(self, x):
return silu(x)
def Normalize(in_channels, norm_type="group"):
    # Overrides the earlier Normalize defined above; for the default norm_type the
    # behaviour (GroupNorm with 32 groups, eps=1e-6) is identical.
    assert norm_type in ["group", "batch", "layer"]
    if norm_type == "group":
        return torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
    elif norm_type == "batch":
        return torch.nn.SyncBatchNorm(in_channels)
    elif norm_type == "layer":
        return nn.LayerNorm(in_channels)
class SamePadConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
bias=True,
padding_type="replicate",
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
if isinstance(stride, int):
stride = (stride,) * 3
# assumes that the input shape is divisible by stride
total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
pad_input = []
for p in total_pad[::-1]: # reverse since F.pad starts from last dim
pad_input.append((p // 2 + p % 2, p // 2))
pad_input = sum(pad_input, tuple())
self.pad_input = pad_input
self.padding_type = padding_type
self.conv = nn.Conv3d(
in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias
)
    def forward(self, x):
        orig_dtype = x.dtype
        x = x.float()
        # apply the "same" padding computed in __init__
        x_padded = F.pad(x, self.pad_input, mode=self.padding_type)
        # cast back to the original dtype (e.g. BFloat16) if needed
        x_padded = x_padded.to(orig_dtype)
        return self.conv(x_padded)
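# Illustrative shape check for SamePadConv3d (arbitrary sizes): with "same" padding the
# output temporal/spatial sizes are the input sizes divided by the corresponding stride.
def _demo_same_pad_conv3d():
    conv = SamePadConv3d(8, 8, kernel_size=3, stride=(2, 1, 1))
    x = torch.randn(1, 8, 4, 16, 16)     # (B, C, T, H, W)
    y = conv(x)
    assert y.shape == (1, 8, 2, 16, 16)  # T halved, H/W preserved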
class TemporalAttention(nn.Module):
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
max_temporal_length=64,
):
"""
a clean multi-head temporal attention
"""
super().__init__()
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = Normalize(channels)
self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
self.attention = QKVAttention(self.num_heads)
self.relative_position_k = RelativePosition(
num_units=channels // self.num_heads,
max_relative_position=max_temporal_length,
)
self.relative_position_v = RelativePosition(
num_units=channels // self.num_heads,
max_relative_position=max_temporal_length,
)
self.proj_out = zero_module(
conv_nd(1, channels, channels, 1)
) # conv_dim, in_channels, out_channels, kernel_size
def forward(self, x, mask=None):
b, c, t, h, w = x.shape
out = rearrange(x, "b c t h w -> (b h w) c t")
# torch.Size([4608, 1152, 2])1
# torch.Size([4608, 3456, 2])2
# torch.Size([4608, 1152, 2])3
# torch.Size([4608, 1152, 2])4
#print(out.shape,end='1\n')
qkv = self.qkv(self.norm(out))
#print(qkv.shape,end='2\n')
len_q = qkv.size()[-1]
len_k, len_v = len_q, len_q
k_rp = self.relative_position_k(len_q, len_k)
v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
out = self.attention(qkv, rp=(k_rp, v_rp))
#print(out.shape,end='3\n')
out = self.proj_out(out)
#print(out.shape,end='4\n')
out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
return x + out
class TemporalAttention_lin(nn.Module):
def __init__(
self,
channels,
num_heads=8,
num_head_channels=-1,
max_temporal_length=64,
):
"""
a clean multi-head temporal attention
"""
super().__init__()
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.LayerNorm(channels)
#self.norm = Normalize(channels)
#self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1))
self.qkv = nn.Linear(channels, channels * 3)
self.attention = QKVAttention(self.num_heads)
self.relative_position_k = RelativePosition(
num_units=channels // self.num_heads,
max_relative_position=max_temporal_length,
)
self.relative_position_v = RelativePosition(
num_units=channels // self.num_heads,
max_relative_position=max_temporal_length,
)
self.proj_out = nn.Linear(channels, channels)
def forward(self, x, mask=None):
b, c, t, h, w = x.shape
out = rearrange(x, "b c t h w -> (b h w) t c")
# torch.Size([4608, 1152, 2])1
# torch.Size([4608, 3456, 2])2
# torch.Size([4608, 1152, 2])3
# torch.Size([4608, 1152, 2])4
#print(out.shape,end='1\n')
qkv = self.qkv(self.norm(out)).transpose(-1, -2)
#print(qkv.shape,end='2\n')
len_q = qkv.size()[-1]
len_k, len_v = len_q, len_q
k_rp = self.relative_position_k(len_q, len_k)
v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
out = self.attention(qkv, rp=(k_rp, v_rp))
out = self.proj_out(out.transpose(-1, -2)).transpose(-1, -2)
#print(out.shape,end='4\n')
out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
return x + out
class AttnBlock3D(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
# self.norm.to(x.device)
# self.norm.to(x.dtype)
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, t, h, w = q.shape
# q = q.reshape(b,c,h*w) # bcl
# q = q.permute(0,2,1) # bcl -> blc l=hw
# k = k.reshape(b,c,h*w) # bcl
q = rearrange(q, "b c t h w -> (b t) (h w) c") # blc
k = rearrange(k, "b c t h w -> (b t) c (h w)") # bcl
w_ = torch.bmm(q, k) # b,l,l
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# v = v.reshape(b,c,h*w)
v = rearrange(v, "b c t h w -> (b t) c (h w)") # bcl
# attend to values
w_ = w_.permute(0, 2, 1) # bll
h_ = torch.bmm(v, w_) # bcl
# h_ = h_.reshape(b,c,h,w)
h_ = rearrange(h_, "(b t) c (h w) -> b c t h w", b=b, h=h)
h_ = self.proj_out(h_)
return x + h_
class MultiHeadAttention3D(nn.Module):
def __init__(self, in_channels, num_heads=8):
super().__init__()
self.in_channels = in_channels
self.num_heads = num_heads
self.head_dim = in_channels // num_heads
assert self.head_dim * num_heads == in_channels, "in_channels must be divisible by num_heads"
self.norm = nn.LayerNorm(in_channels)
self.q_linear = nn.Linear(in_channels, in_channels)
self.k_linear = nn.Linear(in_channels, in_channels)
self.v_linear = nn.Linear(in_channels, in_channels)
self.proj_out = nn.Linear(in_channels, in_channels)
def forward(self, x):
b, c, t, h, w = x.shape
#print(x.shape)
# Normalize and reshape input
h_ = rearrange(x, "b c t h w -> (b t) (h w) c")
h_ = self.norm(h_)
# Linear projections
q = self.q_linear(h_)
k = self.k_linear(h_)
v = self.v_linear(h_)
# Reshape to multi-head
q = rearrange(q, "b l (h d) -> b h l d", h=self.num_heads)
k = rearrange(k, "b l (h d) -> b h l d", h=self.num_heads)
v = rearrange(v, "b l (h d) -> b h l d", h=self.num_heads)
# Scaled Dot-Product Attention
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn = F.softmax(scores, dim=-1)
# Apply attention to values
out = torch.matmul(attn, v)
out = rearrange(out, "b h l d -> b l (h d)")
# Project back to original dimension
out = self.proj_out(out)
# Reshape back to original shape
out = rearrange(out, "(b t) (h w) c -> b c t h w", b=b, h=h, t=t)
#print(out.shape)
return x + out
class SiglipAE(nn.Module):
    def __init__(self):
        super().__init__()
        temporal_stride = 2
        norm_type = "group"
        self.temporal_encoding = nn.Parameter(torch.randn((4, 1152)))
        # self.vision_tower = SigLipVisionTower('google/siglip-so400m-patch14-384')
        self.encoder = nn.Sequential(
            AttnBlock3D(1152),
            TemporalAttention(1152),
            SamePadConv3d(1152, 1152, kernel_size=3, stride=(temporal_stride, 1, 1), padding_type="replicate"),
            AttnBlock3D(1152),
            TemporalAttention(1152),
            SamePadConv3d(1152, 1152, kernel_size=3, stride=(temporal_stride, 1, 1), padding_type="replicate"),
        )

    def forward(self, x):
        b_, c_, t_, h_, w_ = x.shape
        temporal_encoding = self.temporal_encoding.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        temporal_encoding = temporal_encoding.expand(b_, -1, -1, h_, w_)  # (B, T, C, H, W)
        temporal_encoding = temporal_encoding.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
        x = x + temporal_encoding
        x = self.encoder(x)
        return x
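if __name__ == "__main__":
    # Minimal smoke test (illustrative only): the encoder expects SigLIP-like features of
    # shape (B, 1152, T, H, W) with T = 4 to match the learned temporal encoding; the two
    # stride-2 temporal convolutions then compress T from 4 down to 1.
    model = SiglipAE()
    x = torch.randn(1, 1152, 4, 8, 8)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 1152, 1, 8, 8])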