import copy
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import MultiheadAttention

from motion.model.layer_norm_fp16 import LayerNorm, RMSNorm


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block, following the structure of LLaMA.

    Computes ``w2(silu(w1(x)) * w3(x))`` with bias-free linear layers.

    Args:
        dim: Input and output feature size.
        hidden_dim: Nominal hidden width. It is scaled by 2/3 and then rounded
            up to a multiple of ``multiple_of`` (the LLaMA convention).
        multiple_of: Granularity the scaled hidden width is rounded up to.
        device: Optional device forwarded to the linear layers.
        dtype: Optional dtype forwarded to the linear layers.
    """

    def __init__(self, dim, hidden_dim, multiple_of=256, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        hidden_dim = int(2 * hidden_dim / 3)
        # Round up to the next multiple of `multiple_of`.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

    def forward(self, x):
        """Apply the gated feed-forward transform to the last dimension of ``x``."""
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def _get_activation_fn(activation: str):
    """Map an activation name ("relu" or "gelu", case-insensitive) to its function.

    Raises:
        RuntimeError: If the name is not supported.
    """
    if activation.lower() == "relu":
        return F.relu
    elif activation.lower() == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class RefinedLayer(nn.Module):
    """Pre-norm transformer layer: self-attention, optional cross-attention, FFN.

    Every sub-block is applied residually as ``x = x + dropout(sublayer(norm(x)))``.
    Attention modules are built with ``batch_first=False``, so tensors are laid
    out as (seq, batch, d_model), per ``nn.MultiheadAttention`` conventions.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Dropout probability used throughout the layer.
        activation: A callable, "relu"/"gelu" (plain 2-layer FFN), or "swiglu"
            (``SwiGLU`` FFN).
        layer_norm_eps: Epsilon passed to the normalization layers.
        device: Factory argument for the submodules.
        dtype: Factory argument for the submodules.
        max_seq_len: Accepted for interface compatibility; unused in this class.
        position_type: Accepted for interface compatibility; unused in this class.
        word_tokens: If truthy, add a cross-attention sub-block over the
            ``word_tokens`` passed to ``forward``.
        norm_type: "rmsnorm" or "layer" (case-insensitive).
        attention_type: Stored on the instance; not used in this class.

    Raises:
        ValueError: If ``norm_type`` is not supported.
        TypeError: If ``activation`` is neither a string nor callable.
        RuntimeError: If ``activation`` is an unsupported string.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation=F.relu, layer_norm_eps=1e-5, device=None, dtype=None,
                 max_seq_len=196, position_type="static", word_tokens=False,
                 norm_type="rmsnorm", attention_type="torch"):
        # All attention and linear projections in this layer are bias-free.
        factory_kwargs = {"device": device, "dtype": dtype, "bias": False}
        super().__init__()

        norm_key = norm_type.lower()
        if norm_key == "rmsnorm":
            Norm = RMSNorm
        elif norm_key == "layer":
            Norm = LayerNorm
        else:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError(
                "norm_type should be rmsnorm/layer, not {}".format(norm_type))

        self.attention_type = attention_type
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=False, **factory_kwargs)
        if word_tokens:
            self.cross_attn = nn.MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=False, **factory_kwargs)
            self.norm3 = Norm(d_model, layer_norm_eps)
            self.dropout3 = nn.Dropout(dropout)
        self.word_tokens = word_tokens

        self.norm1 = Norm(d_model, layer_norm_eps)
        self.norm2 = Norm(d_model, layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Implementation of the feed-forward sub-block.
        if isinstance(activation, str) and activation.lower() == "swiglu":
            self.ffn = SwiGLU(d_model, dim_feedforward, device=device, dtype=dtype)
        else:
            # Legacy string support; callables (e.g. the default F.relu) are
            # used directly instead of crashing on `.lower()`.
            if isinstance(activation, str):
                activation = _get_activation_fn(activation)
            elif not callable(activation):
                raise TypeError(
                    "activation should be a string or callable, not {}".format(
                        type(activation).__name__))
            self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
            self.dropout = nn.Dropout(dropout)
            self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
            self.ffn = self._ff_block

        self.activation = activation

    def forward(self, src, word_tokens=None, src_mask=None, src_key_padding_mask=None):
        """Run self-attention, optional cross-attention, then the FFN.

        Args:
            src: Input sequence, (seq, batch, d_model).
            word_tokens: Cross-attention memory. Required when the layer was
                built with ``word_tokens=True``; ignored otherwise.
            src_mask: ``attn_mask`` for self-attention.
            src_key_padding_mask: ``key_padding_mask`` for self-attention.

        Raises:
            ValueError: If cross-attention is enabled but ``word_tokens`` is None.
        """
        x = src
        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
        if self.word_tokens:
            if word_tokens is None:
                raise ValueError(
                    "word_tokens must be provided when the layer uses cross-attention")
            x = x + self._csa_block(self.norm3(x), word_tokens)
        x = x + self.dropout2(self.ffn(self.norm2(x)))
        return x

    # self-attention block
    def _sa_block(self, x, attn_mask, key_padding_mask):
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # cross-attention block: queries from x, keys/values from mem
    def _csa_block(self, x, mem, attn_mask=None, key_padding_mask=None):
        x = self.cross_attn(x, mem, mem,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)[0]
        return self.dropout3(x)

    # feed-forward block (non-SwiGLU path): linear -> activation -> dropout -> linear
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return x


class Refined_Transformer(nn.Module):
    """Stack of ``num_layers`` independent deep copies of a ``RefinedLayer``.

    Args:
        refined_layer: Template layer; it is deep-copied, not shared.
        num_layers: Number of stacked layers.
    """

    def __init__(self, refined_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(refined_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, src, word_tokens=None, src_mask=None, src_key_padding_mask=None):
        """Apply each layer in order, passing the same masks and ``word_tokens`` to all."""
        output = src
        for mod in self.layers:
            output = mod(output,
                         word_tokens=word_tokens,
                         src_mask=src_mask,
                         src_key_padding_mask=src_key_padding_mask)
        return output