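"""BEiT-3 backbone: a Multiway Transformer encoder.

A single encoder that handles vision-only, text-only, and joint
vision-language inputs. Self-attention is shared across modalities, while
``multiway_split_position`` routes image and text positions to their
modality-specific expert feed-forward networks inside the torchscale
Encoder.
"""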
import torch
import torch.nn as nn

from vlmo.torchscale.architecture.encoder import Encoder
from vlmo.torchscale.component.embedding import (
    PositionalEmbedding,
    TextEmbedding,
    VisionEmbedding,
)

# NB: "MutliwayEmbedding" (sic) is the class name as spelled upstream in
# torchscale; do not "correct" it.
from vlmo.torchscale.component.multiway_network import MutliwayEmbedding


class BEiT3(nn.Module):
    """Multiway Transformer encoder shared by all BEiT-3 tasks."""

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        # BEiT-3 requires a multiway encoder with its own text vocabulary
        # and separate input/output embeddings.
        assert args.multiway
        assert args.vocab_size > 0
        assert not args.share_encoder_input_output_embed

        self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
        self.vision_embed = VisionEmbedding(
            args.img_size,
            args.patch_size,
            args.in_chans,
            args.encoder_embed_dim,
            contain_mask_token=True,
            prepend_cls_token=True,
        )

        # One positional table per modality. The "+ 2" keeps the vision side
        # consistent with Fairseq, whose position numbering starts at 2.
        embed_positions = MutliwayEmbedding(
            modules=[
                PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
            ],
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,  # embeddings are computed here, not in the encoder
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(
        self,
        textual_tokens=None,
        visual_tokens=None,
        text_padding_position=None,
        attn_mask=None,
        vision_masked_position=None,
        incremental_state=None,
        positions=None,
    ):
        assert textual_tokens is not None or visual_tokens is not None

        if textual_tokens is None:
            # Vision-only: no padding mask; split position -1 routes every
            # token to the vision expert.
            x = self.vision_embed(visual_tokens, vision_masked_position)
            encoder_padding_mask = None
            multiway_split_position = -1
        elif visual_tokens is None:
            # Text-only: split position 0 routes every token to the
            # language expert.
            x = self.text_embed(textual_tokens)
            encoder_padding_mask = text_padding_position
            multiway_split_position = 0
        else:
            # Vision-language: image tokens come first, then text tokens;
            # the split position marks where the language expert takes over.
            x1 = self.vision_embed(visual_tokens, vision_masked_position)
            multiway_split_position = x1.size(1)
            x2 = self.text_embed(textual_tokens)
            # If there are more images than texts in the batch (e.g. several
            # crops or frames per caption), repeat the text side to match.
            diff = x1.shape[0] // x2.shape[0]
            if diff != 1:
                x2 = torch.repeat_interleave(x2, diff, dim=0)
                if text_padding_position is not None:
                    text_padding_position = torch.repeat_interleave(text_padding_position, diff, dim=0)
            x = torch.cat([x1, x2], dim=1)
            if text_padding_position is not None:
                # Image positions are never padding, so prepend an all-False
                # mask for them.
                encoder_padding_mask = torch.cat(
                    [
                        torch.zeros(x1.shape[:-1], device=x1.device, dtype=torch.bool),
                        text_padding_position,
                    ],
                    dim=1,
                )
            else:
                encoder_padding_mask = None

        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
            incremental_state=incremental_state,
            positions=positions,
        )
        return encoder_out
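
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the torchscale-style `EncoderConfig` is importable from
# `vlmo.torchscale.architecture.config` in this fork, and uses BEiT3-base-like
# hyper-parameters; adjust the import path and values to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from vlmo.torchscale.architecture.config import EncoderConfig  # assumed path

    args = EncoderConfig(
        img_size=224,
        patch_size=16,
        vocab_size=64010,  # example tokenizer size (BEiT-3 uses 64010)
        multiway=True,  # required by the asserts in __init__
        share_encoder_input_output_embed=False,
        no_output_layer=True,  # BEiT3 passes output_projection=None
        encoder_embed_dim=768,
        encoder_attention_heads=12,
        encoder_ffn_embed_dim=3072,
        encoder_layers=12,
    )
    model = BEiT3(args)

    images = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
    tokens = torch.randint(0, args.vocab_size, (2, 32))  # (batch, text length)
    padding = torch.zeros(2, 32, dtype=torch.bool)  # True marks padded positions

    # Vision-only, text-only, and joint vision-language forward passes; the
    # torchscale encoder returns a dict whose "encoder_out" entry holds the
    # hidden states.
    vision_out = model(visual_tokens=images)
    text_out = model(textual_tokens=tokens, text_padding_position=padding)
    vl_out = model(
        textual_tokens=tokens,
        visual_tokens=images,
        text_padding_position=padding,
    )
    print(vl_out["encoder_out"].shape)  # (2, 197 + 32, 768) for 224/16 patches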