"""Transformer building blocks: a cross-attention context decoder and a
CLIP-style text encoder."""

import torch
import torch.utils.checkpoint as checkpoint
from torch import nn
from collections import OrderedDict
from timm.models.layers import trunc_normal_
|
|
class Attention(nn.Module):
    """Multi-head attention with separate query, key and value inputs."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Separate projections so that queries and keys/values can come from
        # different sequences (needed for cross-attention).
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v):
        B, N, C = q.shape
        assert k.shape == v.shape
        B, M, C = k.shape

        # Project and split into heads: (B, seq_len, num_heads, head_dim).
        q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads)
        k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads)
        v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads)

        # Scaled dot-product attention scores, (B, num_heads, N, M), with
        # dropout on the attention weights.
        attn = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum over values, then merge heads back to (B, N, C).
        x = torch.einsum('bknm,bmkc->bnkc', attn, v).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
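# Usage sketch (illustrative, not part of the original code). Shapes follow the
# projections above: queries are (B, N, C), keys/values are (B, M, C), and C
# must be divisible by num_heads.
#
#   attn = Attention(dim=256, num_heads=8)
#   q = torch.randn(2, 16, 256)    # 16 query tokens
#   kv = torch.randn(2, 49, 256)   # 49 memory tokens
#   out = attn(q, kv, kv)          # -> (2, 16, 256)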
|
|
class TransformerDecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer: self-attention, cross-attention, MLP."""

    def __init__(
        self,
        d_model,
        nhead,
        dropout=0.1,
    ):
        super().__init__()
        self.self_attn = Attention(d_model, nhead, proj_drop=dropout)
        self.cross_attn = Attention(d_model, nhead, proj_drop=dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x, mem):
        # Self-attention over the (normalized) query tokens.
        q = k = v = self.norm1(x)
        x = x + self.self_attn(q, k, v)
        # Cross-attention: queries attend to the memory sequence.
        q = self.norm2(x)
        x = x + self.cross_attn(q, mem, mem)
        # Position-wise feed-forward network.
        x = x + self.dropout(self.mlp(self.norm3(x)))
        return x
|
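# Usage sketch (illustrative, not part of the original code). Query tokens `x`
# are refined against a memory sequence `mem`; both must have width d_model.
#
#   layer = TransformerDecoderLayer(d_model=256, nhead=4)
#   x = torch.randn(2, 16, 256)     # query tokens
#   mem = torch.randn(2, 49, 256)   # memory tokens (e.g. flattened visual features)
#   x = layer(x, mem)               # -> (2, 16, 256)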
|
|
|
|
class ContextDecoder(nn.Module):
    """Stack of decoder layers in which text embeddings attend to visual context."""

    def __init__(self,
                 transformer_width=256,
                 transformer_heads=4,
                 transformer_layers=6,
                 visual_dim=1024,
                 dropout=0.1,
                 **kwargs):
        super().__init__()

        # Project visual tokens into the decoder width to serve as memory.
        self.memory_proj = nn.Sequential(
            nn.LayerNorm(visual_dim),
            nn.Linear(visual_dim, transformer_width),
            nn.LayerNorm(transformer_width),
        )

        # Project text embeddings into the decoder width to serve as queries.
        self.text_proj = nn.Sequential(
            nn.LayerNorm(visual_dim),
            nn.Linear(visual_dim, transformer_width),
        )

        self.decoder = nn.ModuleList([
            TransformerDecoderLayer(transformer_width, transformer_heads, dropout)
            for _ in range(transformer_layers)
        ])

        # Map the refined queries back to the visual feature dimension.
        self.out_proj = nn.Sequential(
            nn.LayerNorm(transformer_width),
            nn.Linear(transformer_width, visual_dim)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, text, visual):
        B, N, C = visual.shape
        visual = self.memory_proj(visual)
        x = self.text_proj(text)

        for layer in self.decoder:
            x = layer(x, visual)

        return self.out_proj(x)
|
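# Usage sketch (illustrative, not part of the original code; the sizes below
# are placeholders). Text embeddings act as queries over the visual tokens and
# are returned in the original visual_dim.
#
#   decoder = ContextDecoder(transformer_width=256, transformer_heads=4,
#                            transformer_layers=6, visual_dim=1024)
#   text = torch.randn(2, 20, 1024)     # (B, K, visual_dim) text embeddings
#   visual = torch.randn(2, 196, 1024)  # (B, N, visual_dim) visual tokens
#   out = decoder(text, visual)         # -> (2, 20, 1024)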
|
|
|
|
class QuickGELU(nn.Module):
    """Fast GELU approximation used in CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
|
|
|
|
|
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention and MLP with residuals."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor):
        # Keep the (optional) causal mask on the same device/dtype as the input.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask,
                         key_padding_mask=key_padding_mask)[0]

    def forward(self, x: torch.Tensor, key_padding_mask=None):
        x = x + self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)
        x = x + self.mlp(self.ln_2(x))
        return x
|
|
|
class Transformer(nn.Module):
    """Stack of ResidualAttentionBlocks with CLIP-style weight initialization."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_checkpoint=False):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

        # Initialization scheme from CLIP: projection scales depend on depth and width.
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        for block in self.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor):
        for resblock in self.resblocks:
            if self.use_checkpoint:
                # Gradient checkpointing: recompute activations during the
                # backward pass to reduce memory usage.
                x = checkpoint.checkpoint(resblock, x)
            else:
                x = resblock(x)
        return x
|
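# Usage sketch (illustrative, not part of the original code). Because
# nn.MultiheadAttention defaults to sequence-first layout, the stack is driven
# with (seq_len, batch, width) tensors; the sizes below are placeholders.
#
#   encoder = Transformer(width=512, layers=12, heads=8)
#   x = torch.randn(77, 2, 512)   # (seq_len, batch, width)
#   y = encoder(x)                # -> (77, 2, 512)
#
# Setting use_checkpoint=True trades extra compute for lower activation memory
# by recomputing each block during the backward pass.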
|
|
|
|
class TextTransformer(nn.Module):
    """CLIP-style text encoder: token + positional embeddings, a causal
    transformer, and pooling at the end-of-text token."""

    def __init__(
        self,
        context_length: int,
        width: int,
        layers: int,
        vocab_size,
        use_checkpoint=False,
    ):
        super().__init__()
        heads = width // 64
        self.context_length = context_length
        self.width = width
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            attn_mask=self.build_attention_mask(),
            use_checkpoint=use_checkpoint)

        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
        self.ln_final = nn.LayerNorm(width)
        self.token_embedding = nn.Embedding(vocab_size, width)

        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

    def build_attention_mask(self):
        # Causal mask: -inf strictly above the diagonal blocks attention to
        # future tokens; zeros on and below the diagonal allow attention to the
        # current and earlier positions.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)
        return mask

    def forward(self, text):
        x = self.token_embedding(text)  # (B, L, width)
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)          # (L, B, width) for nn.MultiheadAttention
        x = self.transformer(x)
        x = x.permute(1, 0, 2)          # back to (B, L, width)
        x = self.ln_final(x)

        # Pool the feature at the end-of-text token, which is assumed to carry
        # the highest token id in each sequence (the CLIP tokenizer convention).
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        return x
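

if __name__ == '__main__':
    # Minimal smoke test (illustrative only). The vocabulary size, context
    # length, and model width below are placeholders, not values taken from any
    # particular configuration; the test only checks output shapes.
    model = TextTransformer(context_length=77, width=256, layers=2, vocab_size=49408)
    tokens = torch.randint(low=1, high=49407, size=(2, 77))
    tokens[:, -1] = 49407  # largest id last, so argmax-based pooling picks the final position
    feats = model(tokens)
    print(feats.shape)  # expected: torch.Size([2, 256])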