import torch
import torch.nn as nn
import torch.nn.functional as F
class VisionLanguageEmbedding(nn.Module):
    """Wraps a text embedder and a vision embedder behind a single module.

    Either modality may be omitted at call time; when both are given, the
    vision embeddings are concatenated in front of the text embeddings.
    """

    def __init__(self, text_embed, vision_embed):
        super().__init__()
        self.text_embed = text_embed
        self.vision_embed = vision_embed

    def forward(self, textual_tokens, visual_tokens, **kwargs):
        # Single-modality shortcuts: embed whichever input is present.
        if textual_tokens is None:
            return self.vision_embed(visual_tokens)

        if visual_tokens is None:
            return self.text_embed(textual_tokens)

        # Both modalities: embed separately, then concatenate along the
        # sequence dimension, vision tokens first.
        x1 = self.vision_embed(visual_tokens)
        x2 = self.text_embed(textual_tokens)

        return torch.cat([x1, x2], dim=1)
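
# Illustrative usage sketch (not part of the original module): VisionLanguageEmbedding
# only assumes its two sub-embedders are callables returning (batch, seq, dim) tensors
# with a matching embedding dimension. The vocabulary size and sequence lengths below
# are hypothetical.
#
#   embed = VisionLanguageEmbedding(
#       text_embed=TextEmbedding(num_embeddings=64000, embedding_dim=768),
#       vision_embed=VisionEmbedding(img_size=224, patch_size=16, embed_dim=768),
#   )
#   tokens = torch.randint(0, 64000, (2, 16))    # (batch, text_len)
#   images = torch.randn(2, 3, 224, 224)         # (batch, channels, height, width)
#   out = embed(tokens, images)                  # (2, 196 + 16, 768), vision tokens first
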
class VisionEmbedding(nn.Module):
    """Image to Patch Embedding."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        contain_mask_token=False,
        prepend_cls_token=False,
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Each non-overlapping patch is projected to embed_dim by a strided convolution.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

        if contain_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.mask_token = None

        if prepend_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None

    def num_position_embeddings(self):
        if self.cls_token is None:
            return self.num_patches
        else:
            return self.num_patches + 1

    def forward(self, x, masked_position=None, **kwargs):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)

        batch_size, seq_len, _ = x.size()

        # Replace the patch embeddings at masked positions with the learned mask token.
        if masked_position is not None:
            assert self.mask_token is not None
            mask_token = self.mask_token.expand(batch_size, seq_len, -1)
            w = masked_position.unsqueeze(-1).type_as(mask_token)
            x = x * (1 - w) + mask_token * w

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        return x
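
# Shape note (illustrative, not from the original source): with the default
# img_size=224 and patch_size=16, the image is split into 224 / 16 = 14 patches per
# side, i.e. 14 * 14 = 196 patches, so the output is (batch, 196, 768); with
# prepend_cls_token=True a learned CLS token is prepended and the output becomes
# (batch, 197, 768).
#
#   vision_embed = VisionEmbedding(prepend_cls_token=True)
#   patches = vision_embed(torch.randn(2, 3, 224, 224))   # (2, 197, 768)
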
class TextEmbedding(nn.Embedding):
    def reset_parameters(self):
        # Scaled-normal initialization (std = 1 / sqrt(embedding_dim)) instead of
        # nn.Embedding's default N(0, 1) initialization.
        nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
        self._fill_padding_idx_with_zero()
class PositionalEmbedding(nn.Embedding):
    def forward(
        self,
        x,
        positions=None,
        **kwargs,
    ):
        if positions is None:
            # Default positions start at 2 (the lowest indices are reserved, e.g. for
            # padding), so the table needs at least x.size(1) + 2 entries.
            positions = (
                torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
            )
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
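

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file); the vocabulary size,
    # sequence lengths, and padding index below are illustrative only.
    vocab_size, embed_dim, pad_idx = 64000, 768, 1

    text_embed = TextEmbedding(vocab_size, embed_dim, padding_idx=pad_idx)
    vision_embed = VisionEmbedding(embed_dim=embed_dim, prepend_cls_token=True)
    vl_embed = VisionLanguageEmbedding(text_embed, vision_embed)

    tokens = torch.randint(0, vocab_size, (2, 16))     # (batch, text_len)
    images = torch.randn(2, 3, 224, 224)               # (batch, C, H, W)

    x = vl_embed(tokens, images)
    print(x.shape)                                     # torch.Size([2, 213, 768]) = (batch, 197 + 16, dim)

    # Positional embeddings for the fused sequence: positions start at 2, so the
    # table must hold at least seq_len + 2 entries.
    pos_embed = PositionalEmbedding(x.size(1) + 2, embed_dim, padding_idx=pad_idx)
    print(pos_embed(x).shape)                          # torch.Size([1, 213, 768])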