# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import math

import numpy as np
import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper, wrap

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

from vlmo.torchscale.architecture.utils import init_bert_params
from vlmo.torchscale.component.droppath import DropPath
from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
from vlmo.torchscale.component.multiway_network import MultiwayWrapper, set_split_position
from vlmo.torchscale.component.relative_position_bias import RelativePositionBias

# from vlmo.modules.vlmo_utils import no_sync_module_apply
from pytorch_lightning.utilities.rank_zero import rank_zero_info


def no_sync_module_apply(module, fn):
    """Apply ``fn`` to every descendant module without triggering FSDP's param sync.

    FSDP's ``module.apply`` goes through ``_unshard_params_recurse``, which
    all-gathers parameters across ranks. Use this helper when ``fn`` does not
    need the full parameters (e.g. it only flips a flag on each submodule).
    """
    for child in module.children():
        fn(child)
        no_sync_module_apply(child, fn)


class EncoderLayer(nn.Module):
    def __init__(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        # An externally shared attention module can be passed in via ``attn``
        # (used when ``args.share_attn`` is set in the Encoder below).
        self.self_attn = self.build_self_attention(self.embed_dim, args) if attn is None else attn
        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
        self.dropout_module = torch.nn.Dropout(args.dropout)

        if args.drop_path_rate > 0:
            # Stochastic depth: the drop probability grows linearly with depth.
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.normalize_before = args.encoder_normalize_before
        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.encoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = MultiwayWrapper(
                args,
                self.build_ffn(
                    self.embed_dim,
                    self.args,
                ),
            )
        else:
            assert not self.args.multiway
            # Import the xMoE components lazily so they are only required when
            # an MoE layer is actually built.
            from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
            from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate

            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)
        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))

        if args.deepnorm:
            # DeepNorm residual scaling factors.
            if is_encoder_decoder:
                self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
            else:
                self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
        else:
            self.alpha = 1.0

    def build_ffn(self, embed_dim, args):
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.layernorm_eps,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
            one_attn=args.one_attn,
        )

    def residual_connection(self, x, residual):
        return residual * self.alpha + x

    def forward(
        self,
        x,
        encoder_padding_mask,
        attn_mask=None,
        rel_pos=None,
        multiway_split_position=None,
        incremental_state=None,
    ):
        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        if attn_mask is not None:
            # Fill masked (nonzero) positions with a large negative value so
            # they are ignored by the attention softmax.
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            rel_pos=rel_pos,
            incremental_state=incremental_state,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            x = x.transpose(0, 1)
            x, l_aux = self.moe_layer(x)
            x = x.transpose(0, 1)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, l_aux
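

# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the model): the residual pattern applied in
# EncoderLayer.forward above. With pre-LN (``normalize_before=True``) each
# sub-block computes ``residual * alpha + sublayer(norm(residual))``, where,
# for an encoder-only model, ``alpha = (2 * encoder_layers) ** 0.25`` under
# DeepNorm and 1.0 otherwise. The helper below is a minimal, dependency-free
# restatement of that flow for a single sub-layer (dropout and drop-path
# omitted); the name is hypothetical and nothing in this file calls it.
def _prenorm_residual_sketch(x, sublayer, norm, alpha=1.0):
    residual = x
    x = norm(x)
    x = sublayer(x)
    return residual * alpha + x

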
class Encoder(nn.Module):
    def __init__(
        self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
    ):
        self.args = args
        super().__init__(**kwargs)

        self.dropout_module = torch.nn.Dropout(args.dropout)

        embed_dim = args.encoder_embed_dim
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        self.mask_ratio = args.mask_ratio
        self.max_text_len = args.max_text_len
        # Number of image patch tokens (excluding the CLS token).
        self.vision_len = (args.img_size // args.patch_size) * (args.img_size // args.patch_size)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])

        if self.args.share_layer:
            # One layer instance reused (parameter-tied) at every depth.
            single_layer = self.build_encoder_layer(
                args, depth=0, is_moe_layer=False, is_encoder_decoder=is_encoder_decoder
            )
            for i in range(args.encoder_layers):
                self.layers.append(single_layer)
        elif self.args.share_attn:
            # Layers keep their own FFNs but share a single attention module.
            moe_freq = args.moe_freq
            embed_dim = args.encoder_embed_dim
            shared_attn = self.build_self_attention(embed_dim, self.args)
            for i in range(args.encoder_layers):
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        attn=shared_attn,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )
        else:
            moe_freq = args.moe_freq
            for i in range(args.encoder_layers):
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )
        self.num_layers = len(self.layers)

        if args.encoder_normalize_before and args.normalize_output:
            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
        else:
            self.layer_norm = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.encoder_attention_heads,
            )
        else:
            self.relative_position = None

        if args.bert_init:
            self.apply(init_bert_params)

        if args.deepnorm:
            if is_encoder_decoder:
                init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
            else:
                init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.div_(init_scale)

        if args.subln:
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
            else:
                init_scale = math.sqrt(math.log(args.encoder_layers * 2))
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.mul_(init_scale)
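
    # MAE-style masking: ``random_masking`` keeps the CLS token (position 0)
    # plus a random subset of the remaining tokens per sample, and returns the
    # shortened sequence together with ``ids_keep`` (the original indices of
    # the kept non-CLS tokens) so padding/attention masks can later be
    # gathered to match.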
    def random_masking(self, x, mask_ratio):
        """Randomly drop tokens (MAE-style), always keeping the first (CLS) token."""
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        # Sample noise for the L - 1 non-CLS positions; argsort gives a random
        # permutation, and the +1 shift maps it back to original token indices.
        noise = torch.rand(N, L - 1, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device, dtype=torch.long)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Re-attach the CLS token at position 0.
        x0 = x[:, 0, :]
        x0 = x0.reshape(N, 1, D)
        x_masked_add = torch.cat([x0, x_masked], dim=1)
        return x_masked_add, ids_keep

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
            one_attn=args.one_attn,
        )

    def build_output_projection(
        self,
        args,
    ):
        if args.share_encoder_input_output_embed:
            assert args.encoder_embedding_type == "language"
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(args.encoder_embed_dim, args.vocab_size, bias=False)
            torch.nn.init.normal_(output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5)
        return output_projection

    def checkpointing_and_params_allgather(
        self,
        origin_layer,
    ):
        origin_forward = origin_layer.forward
        from deepspeed import checkpointing

        def forward(*args, **kwargs):
            # NOTE: deepspeed checkpointing does not support kwargs.
            ret = checkpointing.checkpoint(origin_forward, *args, **kwargs)
            return ret

        return forward

    def build_encoder_layer(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
        layer = EncoderLayer(
            args,
            depth,
            attn,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        if args.checkpoint_activations:
            rank_zero_info("EncoderLayer params: %s", sum(p.numel() for p in layer.parameters() if p.requires_grad))
            layer = checkpoint_wrapper(layer)
            # layer.ffn = checkpoint_wrapper(layer.ffn,)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def checkpointing_layers(self):
        for i, layer in enumerate(self.layers):
            rank_zero_info(f"Checkpointing wrapper EncoderLayers: {i}")
            self.layers[i] = checkpoint_wrapper(layer)

    def forward_embedding(
        self,
        src_tokens,
        token_embedding=None,
        positions=None,
    ):
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            if src_tokens is not None:
                x = embed + self.embed_positions(src_tokens, positions=positions)
            else:
                x = embed + self.embed_positions(x, positions=positions)

        # is_flip: 0 = no masking, 1 = image-only input masked,
        # 2 = image-text input with the vision segment masked.
        is_flip, ids_keep = 0, None
        if self.mask_ratio > 0:
            if x.shape[1] == self.vision_len + 1:
                # Image-only input: CLS + vision patches.
                x, ids_keep = self.random_masking(x, self.mask_ratio)
                is_flip = 1
            elif x.shape[1] == self.vision_len + self.max_text_len + 1:
                # Image-text input: mask only the vision segment (CLS + patches)
                # and leave the text tokens untouched.
                vision_tokens = x[:, : self.vision_len + 1, :]
                vision_tokens, ids_keep = self.random_masking(vision_tokens, self.mask_ratio)
                x = torch.cat([vision_tokens, x[:, self.vision_len + 1 :]], dim=1)
                is_flip = 2

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x, embed, ids_keep, is_flip
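
    # When masking was applied in forward_embedding, the padding mask, the
    # attention mask, and the multiway split position still refer to the
    # original (unmasked) sequence. forward() below therefore rebuilds
    # ``ids_keep`` in full-sequence coordinates (CLS index prepended and, for
    # image-text inputs, the never-masked text indices appended), gathers the
    # masks with it, and recomputes the split position before running the
    # encoder layers.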
    def forward(
        self,
        src_tokens,
        encoder_padding_mask=None,
        attn_mask=None,
        return_all_hiddens=False,
        token_embeddings=None,
        multiway_split_position=None,
        features_only=False,
        incremental_state=None,
        positions=None,
        **kwargs
    ):
        assert src_tokens is not None or token_embeddings is not None

        if encoder_padding_mask is None:
            if src_tokens is not None:
                encoder_padding_mask = torch.zeros_like(src_tokens, device=src_tokens.device).bool()
            else:
                encoder_padding_mask = torch.zeros(
                    [token_embeddings.size(0), token_embeddings.size(1)],
                    device=token_embeddings.device,
                ).bool()

        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        x, encoder_embedding, ids_keep, is_flip = self.forward_embedding(src_tokens, token_embeddings, positions)

        if is_flip > 0:
            # Rebuild ``ids_keep`` in full-sequence coordinates: CLS first,
            # then the kept vision tokens, then (for image-text inputs) the
            # text positions, which are never masked.
            if is_flip == 2:
                text_ids = (
                    torch.arange(
                        self.vision_len + 1,
                        self.vision_len + 1 + self.max_text_len,
                        device=x.device,
                        dtype=torch.int64,
                    )
                    .unsqueeze(0)
                    .repeat(ids_keep.shape[0], 1)
                )
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep, text_ids], dim=1)
            elif is_flip == 1:
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep], dim=1)

            if encoder_padding_mask is not None:
                encoder_padding_mask = torch.gather(encoder_padding_mask, dim=1, index=ids_keep)
            if attn_mask is not None:
                attn_mask = torch.gather(
                    attn_mask, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, attn_mask.shape[-1])
                )
                attn_mask = torch.gather(
                    attn_mask, dim=2, index=ids_keep.unsqueeze(1).repeat(1, attn_mask.shape[1], 1)
                )
            if multiway_split_position is not None and multiway_split_position > 0:
                # The vision segment shrank, so the split must again point at
                # the first text token of the shortened sequence.
                multiway_split_position = ids_keep.shape[1] - self.max_text_len

        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        encoder_states = []

        if return_all_hiddens:
            encoder_states.append(x)

        rel_pos_bias = None
        if self.relative_position is not None:
            rel_pos_bias = self.relative_position(batch_size=x.size(0), qlen=x.size(1), klen=x.size(1))

        l_aux = []
        for idx, layer in enumerate(self.layers):
            x, l_aux_i = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
                attn_mask=attn_mask,
                rel_pos=rel_pos_bias,
                multiway_split_position=multiway_split_position,
                incremental_state=incremental_state[idx] if incremental_state is not None else None,
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            l_aux.append(l_aux_i)

        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only and self.output_projection is not None:
            x = self.output_projection(x)

        return {
            "encoder_out": x,
            "encoder_embedding": encoder_embedding,
            "encoder_padding_mask": encoder_padding_mask,
            "encoder_states": encoder_states,
            "l_aux": l_aux,
            "multiway_split_position": multiway_split_position,
        }
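

# ---------------------------------------------------------------------------
# Usage sketch (comments only; constructing a full ``args`` config is
# repo-specific, so the names below are hypothetical). The encoder is
# typically fed pre-computed token embeddings and returns a dictionary:
#
#     encoder = Encoder(args, embed_positions=pos_embed)
#     out = encoder(src_tokens=None, token_embeddings=embeds,
#                   multiway_split_position=split)
#     hidden = out["encoder_out"]   # final hidden states
#     aux = out["l_aux"]            # per-layer MoE aux losses (None for dense layers)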