# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn
import torch.nn.functional as F


class VisionLanguageEmbedding(nn.Module):
    """Wrap a text embedding and a vision embedding; when both modalities are
    present, concatenate their outputs along the sequence dimension."""

    def __init__(self, text_embed, vision_embed):
        super().__init__()
        self.text_embed = text_embed
        self.vision_embed = vision_embed

    def forward(self, textual_tokens, visual_tokens, **kwargs):
        if textual_tokens is None:
            return self.vision_embed(visual_tokens)

        if visual_tokens is None:
            return self.text_embed(textual_tokens)

        x1 = self.vision_embed(visual_tokens)
        x2 = self.text_embed(textual_tokens)

        return torch.cat([x1, x2], dim=1)


class VisionEmbedding(nn.Module):
    """Image to Patch Embedding."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        contain_mask_token=False,
        prepend_cls_token=False,
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        # Non-overlapping patch projection: kernel size and stride both equal the patch size.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

        if contain_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.mask_token = None

        if prepend_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None

    def num_position_embeddings(self):
        if self.cls_token is None:
            return self.num_patches
        else:
            return self.num_patches + 1

    def forward(self, x, masked_position=None, **kwargs):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)

        batch_size, seq_len, _ = x.size()

        if masked_position is not None:
            assert self.mask_token is not None
            # Replace embeddings at masked positions with the learnable mask token.
            mask_token = self.mask_token.expand(batch_size, seq_len, -1)
            w = masked_position.unsqueeze(-1).type_as(mask_token)
            x = x * (1 - w) + mask_token * w

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(
                batch_size, -1, -1
            )  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_tokens, x), dim=1)

        return x


class TextEmbedding(nn.Embedding):
    """Token embedding with normal init (std = embedding_dim ** -0.5) and a zeroed padding row."""

    def reset_parameters(self):
        nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
        self._fill_padding_idx_with_zero()


class PositionalEmbedding(nn.Embedding):
    def forward(
        self,
        x,
        positions=None,
        **kwargs,
    ):
        if positions is None:
            # Being consistent with Fairseq, which starts from 2.
            positions = (
                torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
            )
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
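

# Usage sketch (not part of the original module): illustrates how the classes
# above compose into a joint vision-language embedding. The vocabulary size,
# embedding dimension, and batch shapes below are illustrative assumptions,
# not values taken from this repository.
if __name__ == "__main__":
    vision_embed = VisionEmbedding(
        img_size=224, patch_size=16, in_chans=3, embed_dim=768, prepend_cls_token=True
    )
    text_embed = TextEmbedding(num_embeddings=64000, embedding_dim=768)
    vl_embed = VisionLanguageEmbedding(text_embed, vision_embed)

    images = torch.randn(2, 3, 224, 224)          # (batch, channels, height, width)
    token_ids = torch.randint(0, 64000, (2, 16))  # (batch, text sequence length)

    out = vl_embed(token_ids, images)
    # 1 cls token + 14*14 image patches + 16 text tokens = 213 positions.
    print(out.shape)  # torch.Size([2, 213, 768])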