# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn

from vlmo.torchscale.architecture.encoder import Encoder
from vlmo.torchscale.component.embedding import (
    PositionalEmbedding,
    TextEmbedding,
    VisionEmbedding,
)
# NOTE: the "Mutliway" spelling matches the class name defined in multiway_network.
from vlmo.torchscale.component.multiway_network import MutliwayEmbedding


class BEiT3(nn.Module):
    """Multiway Transformer encoder that accepts image tokens, text tokens, or both."""

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        assert args.multiway
        assert args.vocab_size > 0
        assert not args.share_encoder_input_output_embed
        self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
        self.vision_embed = VisionEmbedding(
            args.img_size,
            args.patch_size,
            args.in_chans,
            args.encoder_embed_dim,
            contain_mask_token=True,
            prepend_cls_token=True,
        )
        # being consistent with Fairseq, which starts from 2 for position embedding
        embed_positions = MutliwayEmbedding(
            modules=[
                PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
            ],
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(
        self,
        textual_tokens=None,
        visual_tokens=None,
        text_padding_position=None,
        attn_mask=None,
        vision_masked_position=None,
        incremental_state=None,
        positions=None,
    ):
        assert textual_tokens is not None or visual_tokens is not None

        if textual_tokens is None:
            # Vision-only input: no padding mask, no modality split.
            x = self.vision_embed(visual_tokens, vision_masked_position)
            encoder_padding_mask = None
            multiway_split_position = -1
        elif visual_tokens is None:
            # Text-only input: every position is routed through the language expert.
            x = self.text_embed(textual_tokens)
            encoder_padding_mask = text_padding_position
            multiway_split_position = 0
        else:
            # Image-text input: image tokens come first, text tokens follow the split position.
            x1 = self.vision_embed(visual_tokens, vision_masked_position)
            multiway_split_position = x1.size(1)
            x2 = self.text_embed(textual_tokens)

            # If the visual batch is an integer multiple of the text batch,
            # tile the text embeddings (and padding mask) to match it.
            diff = x1.shape[0] // x2.shape[0]
            if diff != 1:
                x2 = torch.repeat_interleave(x2, diff, dim=0)
                text_padding_position = torch.repeat_interleave(text_padding_position, diff, dim=0)

            x = torch.cat([x1, x2], dim=1)

            if text_padding_position is not None:
                # Image positions are never padded; prepend an all-False mask for them.
                encoder_padding_mask = torch.cat(
                    [
                        torch.zeros(x1.shape[:-1], device=x1.device, dtype=torch.bool),
                        text_padding_position,
                    ],
                    dim=1,
                )
            else:
                encoder_padding_mask = None

        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
            incremental_state=incremental_state,
            positions=positions,
        )
        return encoder_out
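

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the vendored
# torchscale copy also exposes EncoderConfig at
# vlmo.torchscale.architecture.config, as upstream torchscale does, and the
# hyperparameters below are illustrative only; fields not listed here
# (in_chans, max_source_positions, share_encoder_input_output_embed) are left
# at their assumed upstream defaults, which satisfy the asserts in __init__.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from vlmo.torchscale.architecture.config import EncoderConfig  # assumed import path

    args = EncoderConfig(
        img_size=224,
        patch_size=16,
        vocab_size=64010,
        multiway=True,
        encoder_embed_dim=768,
        encoder_attention_heads=12,
        encoder_ffn_embed_dim=3072,
        encoder_layers=12,
    )
    model = BEiT3(args)
    model.eval()

    images = torch.randn(2, 3, 224, 224)                 # (batch, channels, height, width)
    tokens = torch.randint(0, args.vocab_size, (2, 32))   # (batch, text_len) token ids
    padding = torch.zeros(2, 32, dtype=torch.bool)        # True marks padded text positions

    with torch.no_grad():
        out = model(
            textual_tokens=tokens,
            visual_tokens=images,
            text_padding_position=padding,
        )
    # Under these assumed settings the fused sequence is
    # 1 CLS token + 14*14 image patches + 32 text tokens.
    print(out["encoder_out"].shape)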