# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn
from vlmo.torchscale.architecture.encoder import Encoder
from vlmo.torchscale.component.embedding import (
PositionalEmbedding,
TextEmbedding,
VisionEmbedding,
)
# NB: the "Mutliway" spelling matches the class name in the torchscale source.
from vlmo.torchscale.component.multiway_network import MutliwayEmbedding


class BEiT3(nn.Module):
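    """BEiT-3 backbone: one multiway Transformer encoder shared by vision and text.

    Image patches and text tokens are embedded separately, optionally
    concatenated, and routed to modality-specific experts inside the shared
    encoder; ``multiway_split_position`` marks the vision/text boundary.
    """
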
def __init__(self, args, **kwargs):
super().__init__()
self.args = args
        # BEiT-3 requires the multiway (modality-expert) encoder, a text
        # vocabulary, and an untied output projection.
        assert args.multiway
        assert args.vocab_size > 0
        assert not args.share_encoder_input_output_embed

self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
self.vision_embed = VisionEmbedding(
args.img_size,
args.patch_size,
args.in_chans,
args.encoder_embed_dim,
contain_mask_token=True,
prepend_cls_token=True,
)
        # Consistent with Fairseq, whose position numbering starts at 2
        # (padding_idx + 1), hence the +2 extra vision positions. The two
        # positional tables are paired per modality by MutliwayEmbedding.
embed_positions = MutliwayEmbedding(
modules=[
PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
],
dim=1,
)
self.encoder = Encoder(
args,
embed_tokens=None,
embed_positions=embed_positions,
output_projection=None,
is_encoder_decoder=False,
        )

def forward(
self,
textual_tokens=None,
visual_tokens=None,
text_padding_position=None,
attn_mask=None,
vision_masked_position=None,
incremental_state=None,
positions=None,
):
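        """Encode text, vision, or both.

        Exactly one branch runs: vision-only, text-only, or joint. In the
        joint case the vision and text embeddings are concatenated along the
        sequence dimension and the split position routes each half to its
        modality expert.
        """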
        assert textual_tokens is not None or visual_tokens is not None

        if textual_tokens is None:
            # Vision-only: every position goes to the vision expert.
            x = self.vision_embed(visual_tokens, vision_masked_position)
            encoder_padding_mask = None
            multiway_split_position = -1
        elif visual_tokens is None:
            # Text-only: every position goes to the language expert.
            x = self.text_embed(textual_tokens)
            encoder_padding_mask = text_padding_position
            multiway_split_position = 0
        else:
            # Joint: vision embeddings first, then text; the split position
            # is the length of the vision sequence.
            x1 = self.vision_embed(visual_tokens, vision_masked_position)
            multiway_split_position = x1.size(1)
            x2 = self.text_embed(textual_tokens)

            # If the vision batch is an integer multiple of the text batch,
            # tile the text embeddings (and padding mask) to match.
            diff = x1.shape[0] // x2.shape[0]
            if diff != 1:
                x2 = torch.repeat_interleave(x2, diff, dim=0)
                text_padding_position = torch.repeat_interleave(text_padding_position, diff, dim=0)

            x = torch.cat([x1, x2], dim=1)

            if text_padding_position is not None:
                # Vision positions are never padded: prepend an all-False
                # mask for them ahead of the text padding mask.
                encoder_padding_mask = torch.cat(
                    [
                        torch.zeros(x1.shape[:-1], device=x1.device, dtype=torch.bool),
                        text_padding_position,
                    ],
                    dim=1,
                )
            else:
                encoder_padding_mask = None

        # Embeddings are precomputed above, so src_tokens is None and the
        # encoder consumes token_embeddings directly.
        encoder_out = self.encoder(
src_tokens=None,
encoder_padding_mask=encoder_padding_mask,
attn_mask=attn_mask,
token_embeddings=x,
multiway_split_position=multiway_split_position,
incremental_state=incremental_state,
positions=positions,
)
return encoder_out
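

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original file. It assumes
    # torchscale's EncoderConfig (vendored under vlmo.torchscale) accepts the
    # vision/text fields below; the values are illustrative, not canonical.
    from vlmo.torchscale.architecture.config import EncoderConfig

    args = EncoderConfig(
        img_size=224,
        patch_size=16,
        vocab_size=64010,
        multiway=True,
        share_encoder_input_output_embed=False,
    )
    model = BEiT3(args)

    images = torch.randn(2, 3, 224, 224)
    tokens = torch.randint(0, args.vocab_size, (2, 16))
    padding = torch.zeros(2, 16, dtype=torch.bool)  # True marks padded tokens

    out = model(
        textual_tokens=tokens,
        visual_tokens=images,
        text_padding_position=padding,
    )
    # torchscale encoders return a dict; "encoder_out" holds the fused
    # (batch, vision_len + text_len, embed_dim) features.
    print(out["encoder_out"].shape)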