import importlib
import math
import os
from inspect import isfunction
from typing import Optional, Any

import cv2
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from transformers.activations import ACT2FN

try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False
|
|
|
|
|
def count_params(model, verbose=False): |
|
total_params = sum(p.numel() for p in model.parameters()) |
|
if verbose: |
|
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") |
|
return total_params |
|
|
|
|
|
def check_istarget(name, para_list): |
|
""" |
|
name: full name of source para |
|
para_list: partial name of target para |
|
""" |
|
istarget = False |
|
for para in para_list: |
|
if para in name: |
|
return True |
|
return istarget |
|
|
|
|
|
def instantiate_from_config(config): |
|
if not "target" in config: |
|
if config == "__is_first_stage__": |
|
return None |
|
elif config == "__is_unconditional__": |
|
return None |
|
raise KeyError("Expected key `target` to instantiate.") |
|
|
|
return get_obj_from_str(config["target"])(**config.get("params", dict())) |
|
|
|
|
|
def get_obj_from_str(string, reload=False): |
|
module, cls = string.rsplit(".", 1) |
|
if reload: |
|
module_imp = importlib.import_module(module) |
|
importlib.reload(module_imp) |
|
return getattr(importlib.import_module(module, package=None), cls) |
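
# Illustrative usage sketch (not called anywhere in this module): how the
# config-driven factory above is typically used. "torch.nn.Linear" and its
# params are placeholders, not something this module requires.
def _example_instantiate_from_config():
    cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
    layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(8, 4)
    return layer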
|
|
|
|
|
def load_npz_from_dir(data_dir): |
|
data = [ |
|
np.load(os.path.join(data_dir, data_name))["arr_0"] |
|
for data_name in os.listdir(data_dir) |
|
] |
|
data = np.concatenate(data, axis=0) |
|
return data |
|
|
|
|
|
def load_npz_from_paths(data_paths): |
|
data = [np.load(data_path)["arr_0"] for data_path in data_paths] |
|
data = np.concatenate(data, axis=0) |
|
return data |
|
|
|
|
|
def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): |
|
h, w = image.shape[:2] |
|
if resize_short_edge is not None: |
|
k = resize_short_edge / min(h, w) |
|
else: |
|
k = max_resolution / (h * w) |
|
k = k**0.5 |
|
h = int(np.round(h * k / 64)) * 64 |
|
w = int(np.round(w * k / 64)) * 64 |
|
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) |
|
return image |
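
# Illustrative sketch: resize_numpy_image snaps both sides to multiples of 64,
# either targeting a total pixel budget or a fixed short edge. The zero image is
# a placeholder input.
def _example_resize_numpy_image():
    img = np.zeros((480, 640, 3), dtype=np.uint8)
    by_area = resize_numpy_image(img, max_resolution=512 * 512)  # roughly 512*512 pixels
    by_edge = resize_numpy_image(img, resize_short_edge=320)     # short edge near 320
    return by_area.shape, by_edge.shape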
|
|
|
|
|
def setup_dist(args): |
|
if dist.is_initialized(): |
|
return |
|
torch.cuda.set_device(args.local_rank) |
|
torch.distributed.init_process_group("nccl", init_method="env://") |
|
|
|
|
def gather_data(data, return_np=True): |
|
"""gather data from multiple processes to one list""" |
|
data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] |
|
dist.all_gather(data_list, data) |
|
if return_np: |
|
data_list = [data.cpu().numpy() for data in data_list] |
|
return data_list |
|
|
|
|
|
def autocast(f): |
|
def do_autocast(*args, **kwargs): |
|
with torch.cuda.amp.autocast( |
|
enabled=True, |
|
dtype=torch.get_autocast_gpu_dtype(), |
|
cache_enabled=torch.is_autocast_cache_enabled(), |
|
): |
|
return f(*args, **kwargs) |
|
|
|
return do_autocast |
|
|
|
|
|
def extract_into_tensor(a, t, x_shape): |
|
b, *_ = t.shape |
|
out = a.gather(-1, t) |
|
return out.reshape(b, *((1,) * (len(x_shape) - 1))) |
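
# Illustrative sketch: extract_into_tensor gathers one scalar per batch element
# at the given timestep and reshapes it for broadcasting against an image-shaped
# tensor.
def _example_extract_into_tensor():
    alphas = torch.linspace(1.0, 0.0, steps=1000)  # a per-timestep schedule
    t = torch.tensor([0, 10, 999])                 # one timestep per sample
    x = torch.randn(3, 4, 32, 32)
    coeff = extract_into_tensor(alphas, t, x.shape)  # shape (3, 1, 1, 1)
    return coeff * x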
|
|
|
|
|
def noise_like(shape, device, repeat=False): |
|
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( |
|
shape[0], *((1,) * (len(shape) - 1)) |
|
) |
|
noise = lambda: torch.randn(shape, device=device) |
|
return repeat_noise() if repeat else noise() |
|
|
|
|
|
def default(val, d): |
|
if exists(val): |
|
return val |
|
return d() if isfunction(d) else d |
|
|
|
|
|
def exists(val): |
|
return val is not None |
|
|
|
|
|
def identity(*args, **kwargs): |
|
return nn.Identity() |
|
|
|
|
|
def uniq(arr): |
|
return {el: True for el in arr}.keys() |
|
|
|
|
|
def mean_flat(tensor): |
|
""" |
|
Take the mean over all non-batch dimensions. |
|
""" |
|
return tensor.mean(dim=list(range(1, len(tensor.shape)))) |
|
|
|
|
|
def ismap(x): |
|
if not isinstance(x, torch.Tensor): |
|
return False |
|
return (len(x.shape) == 4) and (x.shape[1] > 3) |
|
|
|
|
|
def isimage(x): |
|
if not isinstance(x, torch.Tensor): |
|
return False |
|
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) |
|
|
|
|
|
def max_neg_value(t): |
|
return -torch.finfo(t.dtype).max |
|
|
|
|
|
def shape_to_str(x): |
|
shape_str = "x".join([str(x) for x in x.shape]) |
|
return shape_str |
|
|
|
|
|
def init_(tensor): |
|
dim = tensor.shape[-1] |
|
std = 1 / math.sqrt(dim) |
|
tensor.uniform_(-std, std) |
|
return tensor |
|
|
|
|
|
|
|
def disabled_train(self, mode=True): |
|
"""Overwrite model.train with this function to make sure train/eval mode |
|
does not change anymore.""" |
|
return self |
|
|
|
|
|
def zero_module(module): |
|
""" |
|
Zero out the parameters of a module and return it. |
|
""" |
|
for p in module.parameters(): |
|
p.detach().zero_() |
|
return module |
|
|
|
|
|
def scale_module(module, scale): |
|
""" |
|
Scale the parameters of a module and return it. |
|
""" |
|
for p in module.parameters(): |
|
p.detach().mul_(scale) |
|
return module |
|
|
|
|
|
def conv_nd(dims, *args, **kwargs): |
|
""" |
|
Create a 1D, 2D, or 3D convolution module. |
|
""" |
|
if dims == 1: |
|
return nn.Conv1d(*args, **kwargs) |
|
elif dims == 2: |
|
return nn.Conv2d(*args, **kwargs) |
|
elif dims == 3: |
|
return nn.Conv3d(*args, **kwargs) |
|
raise ValueError(f"unsupported dimensions: {dims}") |
|
|
|
|
|
def linear(*args, **kwargs): |
|
""" |
|
Create a linear module. |
|
""" |
|
return nn.Linear(*args, **kwargs) |
|
|
|
|
|
def avg_pool_nd(dims, *args, **kwargs): |
|
""" |
|
Create a 1D, 2D, or 3D average pooling module. |
|
""" |
|
if dims == 1: |
|
return nn.AvgPool1d(*args, **kwargs) |
|
elif dims == 2: |
|
return nn.AvgPool2d(*args, **kwargs) |
|
elif dims == 3: |
|
return nn.AvgPool3d(*args, **kwargs) |
|
raise ValueError(f"unsupported dimensions: {dims}") |
|
|
|
|
|
def nonlinearity(type="silu"):
    if type == "silu":
        return nn.SiLU()
    elif type == "leaky_relu":
        return nn.LeakyReLU()
    raise ValueError(f"unsupported nonlinearity type: {type}")
|
|
|
|
|
class GroupNormSpecific(nn.GroupNorm): |
|
def forward(self, x): |
|
if x.dtype == torch.float16 or x.dtype == torch.bfloat16: |
|
return super().forward(x).type(x.dtype) |
|
else: |
|
return super().forward(x.float()).type(x.dtype) |
|
|
|
|
|
def normalization(channels, num_groups=32): |
|
""" |
|
Make a standard normalization layer. |
|
:param channels: number of input channels. |
|
:return: an nn.Module for normalization. |
|
""" |
|
return GroupNormSpecific(num_groups, channels) |
|
|
|
|
|
class HybridConditioner(nn.Module): |
|
|
|
def __init__(self, c_concat_config, c_crossattn_config): |
|
super().__init__() |
|
self.concat_conditioner = instantiate_from_config(c_concat_config) |
|
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) |
|
|
|
def forward(self, c_concat, c_crossattn): |
|
c_concat = self.concat_conditioner(c_concat) |
|
c_crossattn = self.crossattn_conditioner(c_crossattn) |
|
return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} |
|
|
|
|
|
|
|
|
|
|
class GEGLU(nn.Module): |
|
def __init__(self, dim_in, dim_out): |
|
super().__init__() |
|
self.proj = nn.Linear(dim_in, dim_out * 2) |
|
|
|
def forward(self, x): |
|
x, gate = self.proj(x).chunk(2, dim=-1) |
|
return x * F.gelu(gate) |
|
|
|
|
|
class FeedForward(nn.Module): |
|
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): |
|
super().__init__() |
|
inner_dim = int(dim * mult) |
|
dim_out = default(dim_out, dim) |
|
project_in = ( |
|
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) |
|
if not glu |
|
else GEGLU(dim, inner_dim) |
|
) |
|
|
|
self.net = nn.Sequential( |
|
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) |
|
) |
|
|
|
def forward(self, x): |
|
return self.net(x) |
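
# Illustrative sketch: FeedForward maps (batch, tokens, dim) back to the same
# shape; glu=True swaps the GELU projection for the GEGLU variant above.
def _example_feedforward():
    ff = FeedForward(dim=64, mult=4, glu=True, dropout=0.1)
    tokens = torch.randn(2, 16, 64)
    return ff(tokens)  # shape (2, 16, 64)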
|
|
|
|
|
|
|
|
|
|
def Normalize(in_channels, num_groups=32): |
|
return torch.nn.GroupNorm( |
|
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True |
|
) |
|
|
|
|
|
class RelativePosition(nn.Module): |
|
"""https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py""" |
|
|
|
def __init__(self, num_units, max_relative_position): |
|
super().__init__() |
|
self.num_units = num_units |
|
self.max_relative_position = max_relative_position |
|
self.embeddings_table = nn.Parameter( |
|
torch.Tensor(max_relative_position * 2 + 1, num_units) |
|
) |
|
nn.init.xavier_uniform_(self.embeddings_table) |
|
|
|
def forward(self, length_q, length_k): |
|
device = self.embeddings_table.device |
|
range_vec_q = torch.arange(length_q, device=device) |
|
range_vec_k = torch.arange(length_k, device=device) |
|
distance_mat = range_vec_k[None, :] - range_vec_q[:, None] |
|
distance_mat_clipped = torch.clamp( |
|
distance_mat, -self.max_relative_position, self.max_relative_position |
|
) |
|
final_mat = distance_mat_clipped + self.max_relative_position |
|
|
|
|
|
final_mat = final_mat.long() |
|
embeddings = self.embeddings_table[final_mat] |
|
return embeddings |
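
# Illustrative sketch: RelativePosition returns one embedding per (query, key)
# offset, clipped to +/- max_relative_position.
def _example_relative_position():
    rp = RelativePosition(num_units=32, max_relative_position=16)
    emb = rp(length_q=8, length_k=8)
    return emb.shape  # torch.Size([8, 8, 32])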
|
|
|
|
|
class TemporalCrossAttention(nn.Module): |
|
def __init__( |
|
self, |
|
query_dim, |
|
context_dim=None, |
|
heads=8, |
|
dim_head=64, |
|
dropout=0.0, |
|
temporal_length=None, |
|
image_length=None, |
|
use_relative_position=False, |
|
img_video_joint_train=False, |
|
use_tempoal_causal_attn=False, |
|
bidirectional_causal_attn=False, |
|
tempoal_attn_type=None, |
|
joint_train_mode="same_batch", |
|
**kwargs, |
|
): |
|
super().__init__() |
|
inner_dim = dim_head * heads |
|
context_dim = default(context_dim, query_dim) |
|
self.context_dim = context_dim |
|
|
|
self.scale = dim_head**-0.5 |
|
self.heads = heads |
|
self.temporal_length = temporal_length |
|
self.use_relative_position = use_relative_position |
|
self.img_video_joint_train = img_video_joint_train |
|
self.bidirectional_causal_attn = bidirectional_causal_attn |
|
self.joint_train_mode = joint_train_mode |
|
assert joint_train_mode in ["same_batch", "diff_batch"] |
|
self.tempoal_attn_type = tempoal_attn_type |
|
|
|
if bidirectional_causal_attn: |
|
assert use_tempoal_causal_attn |
|
if tempoal_attn_type: |
|
assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"] |
|
assert not use_tempoal_causal_attn |
|
assert not ( |
|
img_video_joint_train and (self.joint_train_mode == "same_batch") |
|
) |
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) |
|
self.to_k = nn.Linear(context_dim, inner_dim, bias=False) |
|
self.to_v = nn.Linear(context_dim, inner_dim, bias=False) |
|
|
|
assert not ( |
|
img_video_joint_train |
|
and (self.joint_train_mode == "same_batch") |
|
and use_tempoal_causal_attn |
|
) |
|
if img_video_joint_train: |
|
if self.joint_train_mode == "same_batch": |
|
mask = torch.ones( |
|
[1, temporal_length + image_length, temporal_length + image_length] |
|
) |
|
|
|
|
|
mask[:, temporal_length:, :] = 0 |
|
mask[:, :, temporal_length:] = 0 |
|
self.mask = mask |
|
else: |
|
self.mask = None |
|
elif use_tempoal_causal_attn: |
|
|
|
self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length])) |
|
elif tempoal_attn_type == "sparse_causal": |
|
|
|
mask1 = torch.tril( |
|
torch.ones([1, temporal_length, temporal_length]) |
|
).bool() |
|
mask2 = torch.zeros( |
|
[1, temporal_length, temporal_length] |
|
) |
|
mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril( |
|
torch.ones([1, temporal_length - 2, temporal_length - 2]) |
|
) |
|
mask2 = (1 - mask2).bool() |
|
self.mask = mask1 & mask2 |
|
elif tempoal_attn_type == "sparse_causal_first": |
|
|
|
mask1 = torch.tril( |
|
torch.ones([1, temporal_length, temporal_length]) |
|
).bool() |
|
mask2 = torch.zeros([1, temporal_length, temporal_length]) |
|
mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril( |
|
torch.ones([1, temporal_length - 2, temporal_length - 2]) |
|
) |
|
mask2 = (1 - mask2).bool() |
|
self.mask = mask1 & mask2 |
|
else: |
|
self.mask = None |
|
|
|
if use_relative_position: |
|
assert temporal_length is not None |
|
self.relative_position_k = RelativePosition( |
|
num_units=dim_head, max_relative_position=temporal_length |
|
) |
|
self.relative_position_v = RelativePosition( |
|
num_units=dim_head, max_relative_position=temporal_length |
|
) |
|
|
|
self.to_out = nn.Sequential( |
|
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) |
|
) |
|
|
|
nn.init.constant_(self.to_q.weight, 0) |
|
nn.init.constant_(self.to_k.weight, 0) |
|
nn.init.constant_(self.to_v.weight, 0) |
|
nn.init.constant_(self.to_out[0].weight, 0) |
|
nn.init.constant_(self.to_out[0].bias, 0) |
|
|
|
def forward(self, x, context=None, mask=None): |
|
|
|
|
|
|
|
|
|
|
|
nh = self.heads |
|
out = x |
|
q = self.to_q(out) |
|
|
|
|
|
|
|
context = default(context, x) |
|
|
|
k = self.to_k(context) |
|
v = self.to_v(context) |
|
|
|
|
|
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v)) |
|
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale |
|
|
|
if self.use_relative_position: |
|
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1] |
|
k2 = self.relative_position_k(len_q, len_k) |
|
sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale |
|
sim += sim2 |
|
|
|
        if exists(self.mask):
            if mask is None:
                mask = self.mask.to(sim.device)
            else:
                mask = self.mask.to(sim.device).bool() & mask
|
|
if mask is not None: |
|
max_neg_value = -1e9 |
|
sim = sim + (1 - mask.float()) * max_neg_value |
|
|
|
|
|
|
|
|
|
|
|
attn = sim.softmax(dim=-1) |
|
|
out = einsum("b i j, b j d -> b i d", attn, v) |
|
|
|
        if self.bidirectional_causal_attn:
            mask_reverse = torch.triu(
                torch.ones(
                    [1, self.temporal_length, self.temporal_length], device=sim.device
                )
            )
            # use an explicit large negative constant; the local `max_neg_value`
            # above is only defined when a mask was applied
            sim_reverse = sim.float().masked_fill(mask_reverse == 0, -1e9)
            attn_reverse = sim_reverse.softmax(dim=-1)
            out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
            out += out_reverse
|
|
|
if self.use_relative_position: |
|
v2 = self.relative_position_v(len_q, len_v) |
|
out2 = einsum("b t s, t s d -> b t d", attn, v2) |
|
out += out2 |
|
out = rearrange(out, "(b h) n d -> b n (h d)", h=nh) |
|
return self.to_out(out) |
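
# Illustrative sketch: TemporalCrossAttention treats the frame axis as the token
# axis, so it consumes (batch*height*width, frames, channels) tokens. Its
# projections are zero-initialized, so a fresh module returns an all-zero tensor.
def _example_temporal_cross_attention():
    attn = TemporalCrossAttention(query_dim=64, heads=4, dim_head=16, temporal_length=8)
    tokens = torch.randn(2, 8, 64)  # (b*h*w, t, c)
    return attn(tokens)  # shape (2, 8, 64)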
|
|
|
|
|
class SpatialSelfAttention(nn.Module): |
|
def __init__(self, in_channels): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
|
|
self.norm = Normalize(in_channels) |
|
self.q = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.k = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.v = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.proj_out = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
|
|
def forward(self, x): |
|
h_ = x |
|
h_ = self.norm(h_) |
|
q = self.q(h_) |
|
k = self.k(h_) |
|
v = self.v(h_) |
|
|
|
|
|
b, c, h, w = q.shape |
|
q = rearrange(q, "b c h w -> b (h w) c") |
|
k = rearrange(k, "b c h w -> b c (h w)") |
|
w_ = torch.einsum("bij,bjk->bik", q, k) |
|
|
|
w_ = w_ * (int(c) ** (-0.5)) |
|
w_ = torch.nn.functional.softmax(w_, dim=2) |
|
|
|
|
|
v = rearrange(v, "b c h w -> b c (h w)") |
|
w_ = rearrange(w_, "b i j -> b j i") |
|
h_ = torch.einsum("bij,bjk->bik", v, w_) |
|
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) |
|
h_ = self.proj_out(h_) |
|
|
|
return x + h_ |
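
# Illustrative sketch: SpatialSelfAttention operates on 4D (b, c, h, w) feature
# maps and returns the same shape with a residual connection.
def _example_spatial_self_attention():
    attn = SpatialSelfAttention(in_channels=64)
    feat = torch.randn(2, 64, 16, 16)
    return attn(feat)  # shape (2, 64, 16, 16)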
|
|
|
|
|
class CrossAttention(nn.Module): |
|
def __init__( |
|
self, |
|
query_dim, |
|
context_dim=None, |
|
heads=8, |
|
dim_head=64, |
|
dropout=0.0, |
|
sa_shared_kv=False, |
|
shared_type="only_first", |
|
**kwargs, |
|
): |
|
super().__init__() |
|
inner_dim = dim_head * heads |
|
context_dim = default(context_dim, query_dim) |
|
self.sa_shared_kv = sa_shared_kv |
|
assert shared_type in [ |
|
"only_first", |
|
"all_frames", |
|
"first_and_prev", |
|
"only_prev", |
|
"full", |
|
"causal", |
|
"full_qkv", |
|
] |
|
self.shared_type = shared_type |
|
|
|
self.scale = dim_head**-0.5 |
|
self.heads = heads |
|
self.dim_head = dim_head |
|
|
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) |
|
self.to_k = nn.Linear(context_dim, inner_dim, bias=False) |
|
self.to_v = nn.Linear(context_dim, inner_dim, bias=False) |
|
|
|
self.to_out = nn.Sequential( |
|
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) |
|
) |
|
self.attention_op: Optional[Any] = None |
|
|
|
def forward(self, x, context=None, mask=None): |
|
h = self.heads |
|
b = x.shape[0] |
|
|
|
q = self.to_q(x) |
|
context = default(context, x) |
|
k = self.to_k(context) |
|
v = self.to_v(context) |
|
if self.sa_shared_kv: |
|
if self.shared_type == "only_first": |
|
k, v = map( |
|
lambda xx: rearrange(xx[0].unsqueeze(0), "b n c -> (b n) c") |
|
.unsqueeze(0) |
|
.repeat(b, 1, 1), |
|
(k, v), |
|
) |
|
else: |
|
raise NotImplementedError |
|
|
|
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) |
|
|
|
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale |
|
|
|
if exists(mask): |
|
mask = rearrange(mask, "b ... -> b (...)") |
|
max_neg_value = -torch.finfo(sim.dtype).max |
|
mask = repeat(mask, "b j -> (b h) () j", h=h) |
|
sim.masked_fill_(~mask, max_neg_value) |
|
|
|
|
|
attn = sim.softmax(dim=-1) |
|
|
|
out = einsum("b i j, b j d -> b i d", attn, v) |
|
out = rearrange(out, "(b h) n d -> b n (h d)", h=h) |
|
return self.to_out(out) |
|
|
|
def efficient_forward(self, x, context=None, mask=None): |
|
q = self.to_q(x) |
|
context = default(context, x) |
|
k = self.to_k(context) |
|
v = self.to_v(context) |
|
|
|
b, _, _ = q.shape |
|
q, k, v = map( |
|
lambda t: t.unsqueeze(3) |
|
.reshape(b, t.shape[1], self.heads, self.dim_head) |
|
.permute(0, 2, 1, 3) |
|
.reshape(b * self.heads, t.shape[1], self.dim_head) |
|
.contiguous(), |
|
(q, k, v), |
|
) |
|
|
|
out = xformers.ops.memory_efficient_attention( |
|
q, k, v, attn_bias=None, op=self.attention_op |
|
) |
|
|
|
if exists(mask): |
|
raise NotImplementedError |
|
out = ( |
|
out.unsqueeze(0) |
|
.reshape(b, self.heads, out.shape[1], self.dim_head) |
|
.permute(0, 2, 1, 3) |
|
.reshape(b, out.shape[1], self.heads * self.dim_head) |
|
) |
|
return self.to_out(out) |
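
# Illustrative sketch: plain cross-attention over flattened image tokens with a
# text-like context of a different width. The shapes are placeholders.
def _example_cross_attention():
    attn = CrossAttention(query_dim=64, context_dim=128, heads=4, dim_head=16)
    img_tokens = torch.randn(2, 256, 64)  # (batch, h*w, query_dim)
    txt_tokens = torch.randn(2, 77, 128)  # (batch, seq, context_dim)
    return attn(img_tokens, context=txt_tokens)  # shape (2, 256, 64)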
|
|
|
|
|
class VideoSpatialCrossAttention(CrossAttention): |
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0): |
|
super().__init__(query_dim, context_dim, heads, dim_head, dropout) |
|
|
|
    def forward(self, x, context=None, mask=None):
        b, c, t, h, w = x.shape
        if context is not None:
            context = context.repeat(t, 1, 1)
        # flatten to (b*t, h*w, c), attend, add the residual in the flattened
        # layout, then reshape back to (b, c, t, h, w)
        x_flat = spatial_attn_reshape(x)
        x_flat = super().forward(x_flat, context=context) + x_flat
        return spatial_attn_reshape_back(x_flat, b, h)
|
|
|
|
|
def spatial_attn_reshape(x): |
|
return rearrange(x, "b c t h w -> (b t) (h w) c") |
|
|
|
|
|
def spatial_attn_reshape_back(x, b, h): |
|
return rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h) |
|
|
|
|
|
def temporal_attn_reshape(x): |
|
return rearrange(x, "b c t h w -> (b h w) t c") |
|
|
|
|
|
def temporal_attn_reshape_back(x, b, h, w): |
|
return rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) |
|
|
|
|
|
def local_spatial_temporal_attn_reshape(x, window_size): |
|
B, C, T, H, W = x.shape |
|
NH = H // window_size |
|
NW = W // window_size |
|
|
|
|
|
|
|
x = rearrange( |
|
x, |
|
"b c t (nh wh) (nw ww) -> b c t nh wh nw ww", |
|
nh=NH, |
|
nw=NW, |
|
wh=window_size, |
|
ww=window_size, |
|
).contiguous() |
|
x = rearrange( |
|
x, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c" |
|
) |
|
return x |
|
|
|
|
|
def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t): |
|
B, L, C = x.shape |
|
NH = h // window_size |
|
NW = w // window_size |
|
x = rearrange( |
|
x, |
|
"(b nh nw) (t wh ww) c -> b c t nh wh nw ww", |
|
b=b, |
|
nh=NH, |
|
nw=NW, |
|
t=t, |
|
wh=window_size, |
|
ww=window_size, |
|
) |
|
x = rearrange(x, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)") |
|
return x |
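
# Illustrative sketch: the local reshape groups each window_size x window_size
# spatial window across all frames into one token sequence, and the *_back
# helper inverts it exactly.
def _example_local_window_reshape_roundtrip():
    x = torch.randn(2, 8, 4, 16, 16)  # (b, c, t, h, w); h, w divisible by window_size
    tokens = local_spatial_temporal_attn_reshape(x, window_size=8)  # (b*2*2, t*8*8, c)
    x_back = local_spatial_temporal_attn_reshape_back(tokens, 8, b=2, h=16, w=16, t=4)
    return torch.allclose(x, x_back)  # True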
|
|
|
|
|
class SpatialTemporalTransformer(nn.Module): |
|
""" |
|
Transformer block for video-like data (5D tensor). |
|
First, project the input (aka embedding) with NO reshape. |
|
Then apply standard transformer action. |
|
The 5D -> 3D reshape operation will be done in the specific attention module. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
n_heads, |
|
d_head, |
|
depth=1, |
|
dropout=0.0, |
|
context_dim=None, |
|
|
|
temporal_length=None, |
|
image_length=None, |
|
use_relative_position=True, |
|
img_video_joint_train=False, |
|
cross_attn_on_tempoal=False, |
|
temporal_crossattn_type="selfattn", |
|
order="stst", |
|
temporalcrossfirst=False, |
|
split_stcontext=False, |
|
temporal_context_dim=None, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
|
|
self.in_channels = in_channels |
|
inner_dim = n_heads * d_head |
|
|
|
self.norm = Normalize(in_channels) |
|
self.proj_in = nn.Conv3d( |
|
in_channels, inner_dim, kernel_size=1, stride=1, padding=0 |
|
) |
|
|
|
self.transformer_blocks = nn.ModuleList( |
|
[ |
|
BasicTransformerBlockST( |
|
inner_dim, |
|
n_heads, |
|
d_head, |
|
dropout=dropout, |
|
|
|
context_dim=context_dim, |
|
|
|
temporal_length=temporal_length, |
|
image_length=image_length, |
|
use_relative_position=use_relative_position, |
|
img_video_joint_train=img_video_joint_train, |
|
temporal_crossattn_type=temporal_crossattn_type, |
|
order=order, |
|
temporalcrossfirst=temporalcrossfirst, |
|
split_stcontext=split_stcontext, |
|
temporal_context_dim=temporal_context_dim, |
|
**kwargs, |
|
) |
|
for d in range(depth) |
|
] |
|
) |
|
|
|
self.proj_out = zero_module( |
|
nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) |
|
) |
|
|
|
def forward(self, x, context=None, temporal_context=None, **kwargs): |
|
|
|
assert x.dim() == 5, f"x shape = {x.shape}" |
|
b, c, t, h, w = x.shape |
|
x_in = x |
|
|
|
x = self.norm(x) |
|
x = self.proj_in(x) |
|
|
|
for block in self.transformer_blocks: |
|
x = block(x, context=context, temporal_context=temporal_context, **kwargs) |
|
|
|
x = self.proj_out(x) |
|
return x + x_in |
|
|
|
|
|
class STAttentionBlock2(nn.Module): |
|
def __init__( |
|
self, |
|
channels, |
|
num_heads=1, |
|
num_head_channels=-1, |
|
use_checkpoint=False, |
|
use_new_attention_order=False, |
|
temporal_length=16, |
|
image_length=8, |
|
use_relative_position=False, |
|
img_video_joint_train=False, |
|
|
|
attn_norm_type="group", |
|
use_tempoal_causal_attn=False, |
|
): |
|
""" |
|
version 1: guided_diffusion implemented version |
|
version 2: remove args input argument |
|
""" |
|
super().__init__() |
|
|
|
if num_head_channels == -1: |
|
self.num_heads = num_heads |
|
else: |
|
assert ( |
|
channels % num_head_channels == 0 |
|
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" |
|
self.num_heads = channels // num_head_channels |
|
self.use_checkpoint = use_checkpoint |
|
|
|
self.temporal_length = temporal_length |
|
self.image_length = image_length |
|
self.use_relative_position = use_relative_position |
|
self.img_video_joint_train = img_video_joint_train |
|
self.attn_norm_type = attn_norm_type |
|
assert self.attn_norm_type in ["group", "no_norm"] |
|
self.use_tempoal_causal_attn = use_tempoal_causal_attn |
|
|
|
if self.attn_norm_type == "group": |
|
self.norm_s = normalization(channels) |
|
self.norm_t = normalization(channels) |
|
|
|
self.qkv_s = conv_nd(1, channels, channels * 3, 1) |
|
self.qkv_t = conv_nd(1, channels, channels * 3, 1) |
|
|
|
if self.img_video_joint_train: |
|
mask = torch.ones( |
|
[1, temporal_length + image_length, temporal_length + image_length] |
|
) |
|
mask[:, temporal_length:, :] = 0 |
|
mask[:, :, temporal_length:] = 0 |
|
self.register_buffer("mask", mask) |
|
else: |
|
self.mask = None |
|
|
|
if use_new_attention_order: |
|
|
|
self.attention_s = QKVAttention(self.num_heads) |
|
self.attention_t = QKVAttention(self.num_heads) |
|
else: |
|
|
|
self.attention_s = QKVAttentionLegacy(self.num_heads) |
|
self.attention_t = QKVAttentionLegacy(self.num_heads) |
|
|
|
if use_relative_position: |
|
self.relative_position_k = RelativePosition( |
|
num_units=channels // self.num_heads, |
|
max_relative_position=temporal_length, |
|
) |
|
self.relative_position_v = RelativePosition( |
|
num_units=channels // self.num_heads, |
|
max_relative_position=temporal_length, |
|
) |
|
|
|
self.proj_out_s = zero_module( |
|
conv_nd(1, channels, channels, 1) |
|
) |
|
self.proj_out_t = zero_module( |
|
conv_nd(1, channels, channels, 1) |
|
) |
|
|
|
def forward(self, x, mask=None): |
|
b, c, t, h, w = x.shape |
|
|
|
|
|
out = rearrange(x, "b c t h w -> (b t) c (h w)") |
|
if self.attn_norm_type == "no_norm": |
|
qkv = self.qkv_s(out) |
|
else: |
|
qkv = self.qkv_s(self.norm_s(out)) |
|
out = self.attention_s(qkv) |
|
out = self.proj_out_s(out) |
|
out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h) |
|
        x = x + out  # avoid modifying the input tensor in place
|
|
|
|
|
out = rearrange(x, "b c t h w -> (b h w) c t") |
|
if self.attn_norm_type == "no_norm": |
|
qkv = self.qkv_t(out) |
|
else: |
|
qkv = self.qkv_t(self.norm_t(out)) |
|
|
|
|
|
if self.use_relative_position: |
|
len_q = qkv.size()[-1] |
|
len_k, len_v = len_q, len_q |
|
k_rp = self.relative_position_k(len_q, len_k) |
|
v_rp = self.relative_position_v(len_q, len_v) |
|
out = self.attention_t( |
|
qkv, |
|
rp=(k_rp, v_rp), |
|
mask=self.mask, |
|
use_tempoal_causal_attn=self.use_tempoal_causal_attn, |
|
) |
|
else: |
|
out = self.attention_t( |
|
qkv, |
|
rp=None, |
|
mask=self.mask, |
|
use_tempoal_causal_attn=self.use_tempoal_causal_attn, |
|
) |
|
|
|
out = self.proj_out_t(out) |
|
out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w) |
|
|
|
return x + out |
|
|
|
|
|
class QKVAttentionLegacy(nn.Module): |
|
""" |
|
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
|
""" |
|
|
|
def __init__(self, n_heads): |
|
super().__init__() |
|
self.n_heads = n_heads |
|
|
|
def forward(self, qkv, rp=None, mask=None): |
|
""" |
|
Apply QKV attention. |
|
|
|
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. |
|
:return: an [N x (H * C) x T] tensor after attention. |
|
""" |
|
if rp is not None or mask is not None: |
|
raise NotImplementedError |
|
bs, width, length = qkv.shape |
|
assert width % (3 * self.n_heads) == 0 |
|
ch = width // (3 * self.n_heads) |
|
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) |
|
scale = 1 / math.sqrt(math.sqrt(ch)) |
|
weight = torch.einsum( |
|
"bct,bcs->bts", q * scale, k * scale |
|
) |
|
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) |
|
a = torch.einsum("bts,bcs->bct", weight, v) |
|
return a.reshape(bs, -1, length) |
|
|
|
@staticmethod |
|
def count_flops(model, _x, y): |
|
return count_flops_attn(model, _x, y) |
|
|
|
|
|
class QKVAttention(nn.Module): |
|
""" |
|
A module which performs QKV attention and splits in a different order. |
|
""" |
|
|
|
def __init__(self, n_heads): |
|
super().__init__() |
|
self.n_heads = n_heads |
|
|
|
def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False): |
|
""" |
|
Apply QKV attention. |
|
|
|
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. |
|
:return: an [N x (H * C) x T] tensor after attention. |
|
""" |
|
bs, width, length = qkv.shape |
|
assert width % (3 * self.n_heads) == 0 |
|
ch = width // (3 * self.n_heads) |
|
|
|
        qkv = qkv.contiguous()
|
q, k, v = qkv.chunk(3, dim=1) |
|
scale = 1 / math.sqrt(math.sqrt(ch)) |
|
|
|
|
|
weight = torch.einsum( |
|
"bct,bcs->bts", |
|
(q * scale).view(bs * self.n_heads, ch, length), |
|
(k * scale).view(bs * self.n_heads, ch, length), |
|
) |
|
|
|
|
|
if rp is not None: |
|
k_rp, v_rp = rp |
|
weight2 = torch.einsum( |
|
"bct,tsc->bst", (q * scale).view(bs * self.n_heads, ch, length), k_rp |
|
) |
|
weight += weight2 |
|
|
|
        if use_tempoal_causal_attn:
            assert mask is None, "Not implemented for merging two masks!"
            mask = torch.tril(torch.ones(weight.shape, device=weight.device))
|
else: |
|
if mask is not None: |
|
|
|
c, t, _ = weight.shape |
|
|
|
if mask.shape[-1] > t: |
|
mask = mask[:, :t, :t] |
|
elif mask.shape[-1] < t: |
|
mask_ = torch.zeros([c, t, t]).to(mask.device) |
|
t_ = mask.shape[-1] |
|
mask_[:, :t_, :t_] = mask |
|
mask = mask_ |
|
else: |
|
assert ( |
|
weight.shape[-1] == mask.shape[-1] |
|
), f"weight={weight.shape}, mask={mask.shape}" |
|
|
|
if mask is not None: |
|
INF = -1e8 |
|
weight = weight.float().masked_fill(mask == 0, INF) |
|
|
|
weight = F.softmax(weight.float(), dim=-1).type( |
|
weight.dtype |
|
) |
|
|
|
a = torch.einsum( |
|
"bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length) |
|
) |
|
|
|
if rp is not None: |
|
a2 = torch.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2) |
|
a += a2 |
|
|
|
return a.reshape(bs, -1, length) |
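
# Illustrative sketch: QKVAttention expects a fused (N, 3*H*C, T) tensor, e.g.
# the output of the 1x1 qkv convolutions used above, and returns (N, H*C, T).
def _example_qkv_attention():
    heads, ch_per_head, length = 4, 16, 8
    qkv = torch.randn(2, heads * ch_per_head * 3, length)
    attn = QKVAttention(n_heads=heads)
    return attn(qkv)  # shape (2, 64, 8)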
|
|
|
|
|
def silu(x): |
|
|
|
return x * torch.sigmoid(x) |
|
|
|
|
|
class SiLU(nn.Module): |
|
def __init__(self): |
|
super(SiLU, self).__init__() |
|
|
|
def forward(self, x): |
|
return silu(x) |
|
|
|
|
|
def Normalize(in_channels, norm_type="group"): |
|
    assert norm_type in ["group", "batch", "layer"]
|
if norm_type == "group": |
|
return torch.nn.GroupNorm( |
|
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True |
|
) |
|
elif norm_type == "batch": |
|
return torch.nn.SyncBatchNorm(in_channels) |
|
elif norm_type == "layer": |
|
return nn.LayerNorm(in_channels) |
|
|
|
class SamePadConv3d(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
stride=1, |
|
bias=True, |
|
padding_type="replicate", |
|
): |
|
super().__init__() |
|
if isinstance(kernel_size, int): |
|
kernel_size = (kernel_size,) * 3 |
|
if isinstance(stride, int): |
|
stride = (stride,) * 3 |
|
|
|
|
|
total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) |
|
pad_input = [] |
|
for p in total_pad[::-1]: |
|
pad_input.append((p // 2 + p % 2, p // 2)) |
|
pad_input = sum(pad_input, tuple()) |
|
|
|
self.pad_input = pad_input |
|
self.padding_type = padding_type |
|
|
|
self.conv = nn.Conv3d( |
|
in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias |
|
) |
|
|
|
    def forward(self, x):
        # Cast to fp32 for padding (replication padding can be unsupported for
        # some dtypes/devices), then restore the original dtype before the conv.
        tp = x.dtype
        x_padded = F.pad(x.float(), self.pad_input, mode=self.padding_type).to(tp)
        return self.conv(x_padded)
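
# Illustrative sketch: SamePadConv3d pads so the spatial resolution is preserved
# and the temporal axis only shrinks by the stride (here 8 -> 4 frames).
def _example_same_pad_conv3d():
    conv = SamePadConv3d(16, 16, kernel_size=3, stride=(2, 1, 1))
    x = torch.randn(1, 16, 8, 32, 32)  # (b, c, t, h, w)
    return conv(x)  # shape (1, 16, 4, 32, 32)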
|
|
|
class TemporalAttention(nn.Module): |
|
def __init__( |
|
self, |
|
channels, |
|
num_heads=1, |
|
num_head_channels=-1, |
|
max_temporal_length=64, |
|
): |
|
""" |
|
a clean multi-head temporal attention |
|
""" |
|
super().__init__() |
|
|
|
if num_head_channels == -1: |
|
self.num_heads = num_heads |
|
else: |
|
assert ( |
|
channels % num_head_channels == 0 |
|
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" |
|
self.num_heads = channels // num_head_channels |
|
|
|
self.norm = Normalize(channels) |
|
self.qkv = zero_module(conv_nd(1, channels, channels * 3, 1)) |
|
self.attention = QKVAttention(self.num_heads) |
|
self.relative_position_k = RelativePosition( |
|
num_units=channels // self.num_heads, |
|
max_relative_position=max_temporal_length, |
|
) |
|
self.relative_position_v = RelativePosition( |
|
num_units=channels // self.num_heads, |
|
max_relative_position=max_temporal_length, |
|
) |
|
self.proj_out = zero_module( |
|
conv_nd(1, channels, channels, 1) |
|
) |
|
|
|
def forward(self, x, mask=None): |
|
b, c, t, h, w = x.shape |
|
out = rearrange(x, "b c t h w -> (b h w) c t") |
|
|
|
|
|
|
|
|
|
|
|
qkv = self.qkv(self.norm(out)) |
|
|
|
|
|
len_q = qkv.size()[-1] |
|
len_k, len_v = len_q, len_q |
|
|
|
k_rp = self.relative_position_k(len_q, len_k) |
|
v_rp = self.relative_position_v(len_q, len_v) |
|
out = self.attention(qkv, rp=(k_rp, v_rp)) |
|
|
|
out = self.proj_out(out) |
|
|
|
out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w) |
|
|
|
return x + out |
|
class TemporalAttention_lin(nn.Module): |
|
def __init__( |
|
self, |
|
channels, |
|
num_heads=8, |
|
num_head_channels=-1, |
|
max_temporal_length=64, |
|
): |
|
""" |
|
a clean multi-head temporal attention |
|
""" |
|
super().__init__() |
|
|
|
if num_head_channels == -1: |
|
self.num_heads = num_heads |
|
else: |
|
assert ( |
|
channels % num_head_channels == 0 |
|
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" |
|
self.num_heads = channels // num_head_channels |
|
|
|
self.norm = nn.LayerNorm(channels) |
|
|
|
|
|
self.qkv = nn.Linear(channels, channels * 3) |
|
self.attention = QKVAttention(self.num_heads) |
|
self.relative_position_k = RelativePosition( |
|
num_units=channels // self.num_heads, |
|
max_relative_position=max_temporal_length, |
|
) |
|
self.relative_position_v = RelativePosition( |
|
num_units=channels // self.num_heads, |
|
max_relative_position=max_temporal_length, |
|
) |
|
self.proj_out = nn.Linear(channels, channels) |
|
|
|
def forward(self, x, mask=None): |
|
b, c, t, h, w = x.shape |
|
out = rearrange(x, "b c t h w -> (b h w) t c") |
|
|
|
|
|
|
|
|
|
|
|
qkv = self.qkv(self.norm(out)).transpose(-1, -2) |
|
|
|
|
|
len_q = qkv.size()[-1] |
|
len_k, len_v = len_q, len_q |
|
|
|
k_rp = self.relative_position_k(len_q, len_k) |
|
v_rp = self.relative_position_v(len_q, len_v) |
|
|
|
out = self.attention(qkv, rp=(k_rp, v_rp)) |
|
|
|
out = self.proj_out(out.transpose(-1, -2)).transpose(-1, -2) |
|
|
|
|
|
out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w) |
|
|
|
return x + out |
|
|
|
class AttnBlock3D(nn.Module): |
|
def __init__(self, in_channels): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
|
|
self.norm = Normalize(in_channels) |
|
self.q = torch.nn.Conv3d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.k = torch.nn.Conv3d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.v = torch.nn.Conv3d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.proj_out = torch.nn.Conv3d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
|
|
def forward(self, x): |
|
h_ = x |
|
|
|
|
|
h_ = self.norm(h_) |
|
q = self.q(h_) |
|
k = self.k(h_) |
|
v = self.v(h_) |
|
|
|
b, c, t, h, w = q.shape |
|
|
|
|
|
|
|
q = rearrange(q, "b c t h w -> (b t) (h w) c") |
|
k = rearrange(k, "b c t h w -> (b t) c (h w)") |
|
|
|
w_ = torch.bmm(q, k) |
|
w_ = w_ * (int(c) ** (-0.5)) |
|
w_ = torch.nn.functional.softmax(w_, dim=2) |
|
|
|
|
|
v = rearrange(v, "b c t h w -> (b t) c (h w)") |
|
|
|
|
|
w_ = w_.permute(0, 2, 1) |
|
h_ = torch.bmm(v, w_) |
|
|
|
|
|
h_ = rearrange(h_, "(b t) c (h w) -> b c t h w", b=b, h=h) |
|
|
|
h_ = self.proj_out(h_) |
|
|
|
return x + h_ |
|
|
|
class MultiHeadAttention3D(nn.Module): |
|
def __init__(self, in_channels, num_heads=8): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
self.num_heads = num_heads |
|
self.head_dim = in_channels // num_heads |
|
|
|
assert self.head_dim * num_heads == in_channels, "in_channels must be divisible by num_heads" |
|
|
|
self.norm = nn.LayerNorm(in_channels) |
|
self.q_linear = nn.Linear(in_channels, in_channels) |
|
self.k_linear = nn.Linear(in_channels, in_channels) |
|
self.v_linear = nn.Linear(in_channels, in_channels) |
|
self.proj_out = nn.Linear(in_channels, in_channels) |
|
|
|
def forward(self, x): |
|
b, c, t, h, w = x.shape |
|
|
|
|
|
h_ = rearrange(x, "b c t h w -> (b t) (h w) c") |
|
h_ = self.norm(h_) |
|
|
|
|
|
q = self.q_linear(h_) |
|
k = self.k_linear(h_) |
|
v = self.v_linear(h_) |
|
|
|
|
|
q = rearrange(q, "b l (h d) -> b h l d", h=self.num_heads) |
|
k = rearrange(k, "b l (h d) -> b h l d", h=self.num_heads) |
|
v = rearrange(v, "b l (h d) -> b h l d", h=self.num_heads) |
|
|
|
|
|
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5) |
|
attn = F.softmax(scores, dim=-1) |
|
|
|
|
|
out = torch.matmul(attn, v) |
|
out = rearrange(out, "b h l d -> b l (h d)") |
|
|
|
|
|
out = self.proj_out(out) |
|
|
|
|
|
out = rearrange(out, "(b t) (h w) c -> b c t h w", b=b, h=h, t=t) |
|
|
|
return x + out |
|
|
|
|
|
class SiglipAE(nn.Module):
    def __init__(self):
        super().__init__()
        temporal_stride = 2
        norm_type = "group"

        # learned temporal positional encoding for 4 input frames of width 1152
        self.temporal_encoding = nn.Parameter(torch.randn((4, 1152)))

        self.encoder = nn.Sequential(
            AttnBlock3D(1152),
            TemporalAttention(1152),
            SamePadConv3d(
                1152,
                1152,
                kernel_size=3,
                stride=(temporal_stride, 1, 1),
                padding_type="replicate",
            ),
            AttnBlock3D(1152),
            TemporalAttention(1152),
            SamePadConv3d(
                1152,
                1152,
                kernel_size=3,
                stride=(temporal_stride, 1, 1),
                padding_type="replicate",
            ),
        )

    def forward(self, x):
        b_, c_, t_, h_, w_ = x.shape

        # broadcast the (t, c) temporal encoding to (b, c, t, h, w) and add it
        temporal_encoding = self.temporal_encoding.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        temporal_encoding = temporal_encoding.expand(b_, -1, -1, h_, w_)
        temporal_encoding = temporal_encoding.permute(0, 2, 1, 3, 4)
        x = x + temporal_encoding

        x = self.encoder(x)
        return x
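
# Illustrative sketch: SiglipAE expects (b, 1152, 4, h, w) features, 4 frames to
# match the learned temporal encoding, and halves the temporal axis twice, so
# the output keeps a single temporal slot.
def _example_siglip_ae():
    ae = SiglipAE()
    feats = torch.randn(1, 1152, 4, 8, 8)
    return ae(feats)  # shape (1, 1152, 1, 8, 8)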
|
|
|
|