import math

import numpy as np
import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper, wrap

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

from vlmo.torchscale.architecture.utils import init_bert_params
from vlmo.torchscale.component.droppath import DropPath
from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
from vlmo.torchscale.component.multiway_network import MultiwayWrapper, set_split_position
from vlmo.torchscale.component.relative_position_bias import RelativePositionBias

# Top1Gate, Top2Gate and MOELayer are referenced below for MoE layers but were not
# imported. The module paths here mirror the upstream torchscale layout and are an
# assumption; adjust them to wherever these classes live in this repo.
from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate

from pytorch_lightning.utilities.rank_zero import rank_zero_info


def no_sync_module_apply(module, fn):
    """Recursively apply ``fn`` to all submodules without syncing params across ranks.

    FSDP's ``Module.apply`` goes through ``_unshard_params_recurse``, which gathers
    (syncs) parameters across ranks. Use this helper instead when ``fn`` does not
    need the full parameters, e.g. when it only sets a flag on each submodule.
    """
    for child in module.children():
        fn(child)
        no_sync_module_apply(child, fn)
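
# Typical use (see EncoderLayer.forward / Encoder.forward below): broadcast the
# multiway split position to every submodule without an FSDP all-gather, e.g.
#   no_sync_module_apply(encoder, set_split_position(split_position))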


class EncoderLayer(nn.Module):
    def __init__(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = self.build_self_attention(self.embed_dim, args) if attn is None else attn
        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
        self.dropout_module = torch.nn.Dropout(args.dropout)

        if args.drop_path_rate > 0:
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.normalize_before = args.encoder_normalize_before
        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.encoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = MultiwayWrapper(
                args,
                self.build_ffn(
                    self.embed_dim,
                    self.args,
                ),
            )
        else:
            assert not self.args.multiway
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)
        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))

        if args.deepnorm:
            if is_encoder_decoder:
                self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
            else:
                self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
        else:
            self.alpha = 1.0
    def build_ffn(self, embed_dim, args):
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.layernorm_eps,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
            one_attn=args.one_attn,
        )

    def residual_connection(self, x, residual):
        return residual * self.alpha + x
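
    # DeepNorm-style residual scaling: with args.deepnorm the residual branch is
    # multiplied by alpha = (2 * encoder_layers) ** 0.25 for encoder-only models
    # (see __init__ above), which stabilizes training of very deep post-norm stacks;
    # otherwise alpha = 1.0 and this is a plain residual connection.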

    def forward(
        self,
        x,
        encoder_padding_mask,
        attn_mask=None,
        rel_pos=None,
        multiway_split_position=None,
        incremental_state=None,
    ):
        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            rel_pos=rel_pos,
            incremental_state=incremental_state,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            x = x.transpose(0, 1)
            x, l_aux = self.moe_layer(x)
            x = x.transpose(0, 1)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, l_aux


class Encoder(nn.Module):
    def __init__(
        self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
    ):
        self.args = args
        super().__init__(**kwargs)

        self.dropout_module = torch.nn.Dropout(args.dropout)

        embed_dim = args.encoder_embed_dim
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
        self.mask_ratio = args.mask_ratio
        self.max_text_len = args.max_text_len
        self.vision_len = (args.img_size // args.patch_size) * (args.img_size // args.patch_size)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])
        if self.args.share_layer:
            single_layer = self.build_encoder_layer(
                args, depth=0, is_moe_layer=False, is_encoder_decoder=is_encoder_decoder
            )
            for i in range(args.encoder_layers):
                self.layers.append(single_layer)
        elif self.args.share_attn:
            moe_freq = args.moe_freq
            embed_dim = args.encoder_embed_dim
            shared_attn = self.build_self_attention(embed_dim, self.args)
            for i in range(args.encoder_layers):
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        attn=shared_attn,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )
        else:
            moe_freq = args.moe_freq
            for i in range(args.encoder_layers):
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )
        self.num_layers = len(self.layers)

        if args.encoder_normalize_before and args.normalize_output:
            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
        else:
            self.layer_norm = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.encoder_attention_heads,
            )
        else:
            self.relative_position = None

        if args.bert_init:
            self.apply(init_bert_params)

        if args.deepnorm:
            if is_encoder_decoder:
                init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
            else:
                init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.div_(init_scale)

        if args.subln:
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
            else:
                init_scale = math.sqrt(math.log(args.encoder_layers * 2))
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.mul_(init_scale)
    def random_masking(self, x, mask_ratio):
        """MAE-style random masking that always keeps the leading CLS token.

        x: (N, L, D) where position 0 is the CLS token and positions 1..L-1 are
        patch tokens. Returns the kept tokens (CLS + a random subset of patches)
        and the indices of the kept patch positions.
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        # Sample noise over the L - 1 patch positions (CLS excluded), keep the
        # len_keep positions with the smallest noise; the +1 shifts indices past CLS.
        noise = torch.rand(N, L - 1, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device, dtype=torch.long)
        ids_keep = ids_shuffle[:, :len_keep]

        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Prepend the CLS token back onto the kept patch tokens.
        x0 = x[:, 0, :]
        x0 = x0.reshape(N, 1, D)
        x_masked_add = torch.cat([x0, x_masked], dim=1)
        return x_masked_add, ids_keep
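
    # Shape example (values for illustration only): for a 224x224 image with 16x16
    # patches, x has L = 196 + 1 = 197 tokens. With mask_ratio = 0.5,
    # len_keep = int(197 * 0.5) = 98, so random_masking returns 1 + 98 = 99 tokens
    # per sample and ids_keep has shape (N, 98).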

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
            one_attn=args.one_attn,
        )

    def build_output_projection(
        self,
        args,
    ):
        if args.share_encoder_input_output_embed:
            assert args.encoder_embedding_type == "language"
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(args.encoder_embed_dim, args.vocab_size, bias=False)
            torch.nn.init.normal_(output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5)
        return output_projection

    def checkpointing_and_params_allgather(
        self,
        origin_layer,
    ):
        """Return ``origin_layer``'s forward wrapped with DeepSpeed activation checkpointing."""
        origin_forward = origin_layer.forward

        from deepspeed import checkpointing

        def forward(*args, **kwargs):
            ret = checkpointing.checkpoint(origin_forward, *args, **kwargs)
            return ret

        return forward
    def build_encoder_layer(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
        layer = EncoderLayer(
            args,
            depth,
            attn,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        if args.checkpoint_activations:
            rank_zero_info("EncoderLayer params: %s", sum(p.numel() for p in layer.parameters() if p.requires_grad))
            layer = checkpoint_wrapper(layer)

        if args.fsdp:
            layer = wrap(layer)
        return layer
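
    # Note: checkpoint_wrapper trades compute for memory by re-running the layer's
    # forward during backward instead of storing activations; fairscale's wrap()
    # shards the layer's parameters across data-parallel ranks under FSDP.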

    def checkpointing_layers(self):
        for i, layer in enumerate(self.layers):
            rank_zero_info(f"Checkpointing wrapper EncoderLayers: {i}")
            self.layers[i] = checkpoint_wrapper(layer)
    def forward_embedding(
        self,
        src_tokens,
        token_embedding=None,
        positions=None,
    ):
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            if src_tokens is not None:
                x = embed + self.embed_positions(src_tokens, positions=positions)
            else:
                x = embed + self.embed_positions(x, positions=positions)
        is_flip, ids_keep = 0, None
        if self.mask_ratio > 0:
            if x.shape[1] == self.vision_len + 1:
                x, ids_keep = self.random_masking(x, self.mask_ratio)
                is_flip = 1
            elif x.shape[1] == self.vision_len + self.max_text_len + 1:
                vision_tokens = x[:, : self.vision_len + 1, :]
                vision_tokens, ids_keep = self.random_masking(vision_tokens, self.mask_ratio)
                x = torch.cat([vision_tokens, x[:, self.vision_len + 1 :]], dim=1)
                is_flip = 2
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x, embed, ids_keep, is_flip
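
    # is_flip codes returned above: 0 = no masking applied, 1 = vision-only input
    # with patch tokens randomly masked, 2 = vision+text input where only the
    # vision tokens were masked; ids_keep holds the kept patch indices (or None).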

    def forward(
        self,
        src_tokens,
        encoder_padding_mask=None,
        attn_mask=None,
        return_all_hiddens=False,
        token_embeddings=None,
        multiway_split_position=None,
        features_only=False,
        incremental_state=None,
        positions=None,
        **kwargs
    ):
        assert src_tokens is not None or token_embeddings is not None

        if encoder_padding_mask is None:
            if src_tokens is not None:
                encoder_padding_mask = torch.zeros_like(src_tokens, device=src_tokens.device).bool()
            else:
                encoder_padding_mask = torch.zeros(
                    [token_embeddings.size(0), token_embeddings.size(1)],
                    device=token_embeddings.device,
                ).bool()

        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        x, encoder_embedding, ids_keep, is_flip = self.forward_embedding(src_tokens, token_embeddings, positions)
        if is_flip > 0:
            if is_flip == 2:
                text_ids = (
                    torch.arange(
                        self.vision_len + 1, self.vision_len + 1 + self.max_text_len, device=x.device, dtype=torch.int64
                    )
                    .unsqueeze(0)
                    .repeat(ids_keep.shape[0], 1)
                )
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep, text_ids], dim=1)
            elif is_flip == 1:
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep], dim=1)
            if encoder_padding_mask is not None:
                encoder_padding_mask = torch.gather(encoder_padding_mask, dim=1, index=ids_keep)
            if attn_mask is not None:
                attn_mask = torch.gather(
                    attn_mask, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, attn_mask.shape[-1])
                )
                attn_mask = torch.gather(attn_mask, dim=2, index=ids_keep.unsqueeze(1).repeat(1, attn_mask.shape[1], 1))
            if multiway_split_position is not None and multiway_split_position > 0:
                multiway_split_position = ids_keep.shape[1] - self.max_text_len
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        encoder_states = []

        if return_all_hiddens:
            encoder_states.append(x)

        rel_pos_bias = None
        if self.relative_position is not None:
            rel_pos_bias = self.relative_position(batch_size=x.size(0), qlen=x.size(1), klen=x.size(1))

        l_aux = []
        for idx, layer in enumerate(self.layers):
            x, l_aux_i = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
                attn_mask=attn_mask,
                rel_pos=rel_pos_bias,
                multiway_split_position=multiway_split_position,
                incremental_state=incremental_state[idx] if incremental_state is not None else None,
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            l_aux.append(l_aux_i)

        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only and self.output_projection is not None:
            x = self.output_projection(x)

        return {
            "encoder_out": x,
            "encoder_embedding": encoder_embedding,
            "encoder_padding_mask": encoder_padding_mask,
            "encoder_states": encoder_states,
            "l_aux": l_aux,
            "multiway_split_position": multiway_split_position,
        }
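
    # A minimal usage sketch (assumption: `args` exposes the fields read above, e.g.
    # encoder_embed_dim, encoder_layers, multiway, mask_ratio, ...; names below are
    # illustrative only):
    #
    #   encoder = Encoder(args, embed_positions=pos_embed_module)
    #   out = encoder(
    #       src_tokens=None,
    #       token_embeddings=patch_and_text_embeds,      # (B, L, C)
    #       multiway_split_position=num_vision_tokens,   # modality boundary
    #       features_only=True,
    #   )
    #   hidden_states = out["encoder_out"]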