Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import copy | |
from torch.nn import MultiheadAttention | |
from motion.model.layer_norm_fp16 import LayerNorm, RMSNorm | |
import numpy as np | |
import math | |
class SwiGLU(nn.Module):
    """Gated feed-forward block that follows the structure used by LLaMA.

    The requested ``hidden_dim`` is first scaled by 2/3 and then rounded up
    to the next multiple of ``multiple_of``. The forward pass computes
    ``w2(silu(w1(x)) * w3(x))``; none of the three projections has a bias.
    """

    def __init__(self, dim, hidden_dim, multiple_of = 256):
        super().__init__()
        # Shrink to two thirds of the requested width (truncating to int).
        scaled = int(2 * hidden_dim / 3)
        # Round up so the inner width is a whole number of `multiple_of` units.
        n_units = (scaled + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * n_units

        # Keep the w1, w2, w3 creation order so parameter initialisation and
        # state-dict ordering stay the same.
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        """Return the gated projection of ``x`` back to width ``dim``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
def _get_activation_fn(activation: str): | |
if activation.lower() == "relu": | |
return F.relu | |
elif activation.lower() == "gelu": | |
return F.gelu | |
raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) | |
def _get_clones(module, N): | |
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) | |
class RefinedLayer(nn.Module):
    """Pre-norm transformer layer: self-attention, optional cross-attention, feed-forward.

    Each sub-block is applied as ``x = x + block(norm(x))``. The attention
    modules are created with ``batch_first=False``, so inputs are expected in
    the sequence-first layout that ``nn.MultiheadAttention`` uses by default.

    Args:
        d_model: Feature width of the layer.
        nhead: Number of attention heads.
        dim_feedforward: Width requested for the feed-forward block.
        dropout: Dropout probability used inside attention and on every
            residual branch.
        activation: Either a callable, or one of the strings "relu", "gelu"
            or "swiglu" (case-insensitive). "swiglu" selects the ``SwiGLU``
            feed-forward; everything else selects linear -> activation ->
            dropout -> linear.
        layer_norm_eps: Epsilon handed to the normalisation layers.
        device, dtype: Forwarded to the attention and linear layers.
        max_seq_len, position_type: Accepted for interface compatibility;
            not used by this class.
        word_tokens: When true, a cross-attention block over the
            ``word_tokens`` tensor given to ``forward`` is added.
        norm_type: "rmsnorm" or "layer" (case-insensitive).
        attention_type: Stored on the instance; not otherwise used here.

    Raises:
        ValueError: If ``norm_type`` is not recognised.
        RuntimeError: If ``activation`` is an unsupported string.
    """

    # NOTE(review): neither name is set as an attribute by this class; kept
    # as-is in case external tooling reads it.
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model, nhead, dim_feedforward = 2048, dropout = 0.1,
                 activation = F.relu, layer_norm_eps = 1e-5, device=None, dtype=None,
                 max_seq_len=196, position_type="static", word_tokens=False,
                 norm_type="rmsnorm", attention_type="torch"):
        factory_kwargs = {'device': device, 'dtype': dtype, "bias": False}
        super().__init__()

        norm_key = norm_type.lower()
        if norm_key == "rmsnorm":
            Norm = RMSNorm
        elif norm_key == "layer":
            Norm = LayerNorm
        else:
            # Previously fell through and failed later with an unbound local.
            raise ValueError(
                "norm_type should be rmsnorm/layer, not {}".format(norm_type))

        self.attention_type = attention_type
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=False, **factory_kwargs)

        if word_tokens:
            self.cross_attn = nn.MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=False, **factory_kwargs)
            self.norm3 = Norm(d_model, layer_norm_eps)
            self.dropout3 = nn.Dropout(dropout)
        self.word_tokens = word_tokens

        self.norm1 = Norm(d_model, layer_norm_eps)
        self.norm2 = Norm(d_model, layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Feed-forward selection. Only call .lower() on real strings so that a
        # callable activation (including the default F.relu) is accepted.
        use_swiglu = isinstance(activation, str) and activation.lower() == "swiglu"
        if use_swiglu:
            self.ffn = SwiGLU(d_model, dim_feedforward)
        else:
            # Legacy string support for activation function.
            if isinstance(activation, str):
                activation = _get_activation_fn(activation)
            self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
            self.dropout = nn.Dropout(dropout)
            self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
            self.ffn = self._ff_block
        self.activation = activation

    def forward(
            self,
            src,
            word_tokens = None,
            src_mask = None,
            src_key_padding_mask = None):
        """Run the layer on ``src`` and return a tensor of the same shape.

        Args:
            src: Input sequence.
            word_tokens: Memory for cross-attention; required when the layer
                was built with ``word_tokens=True`` and ignored otherwise.
            src_mask: Attention mask for self-attention.
            src_key_padding_mask: Key padding mask for self-attention.

        Raises:
            ValueError: If cross-attention is enabled but ``word_tokens`` is
                not supplied.
        """
        if self.word_tokens and word_tokens is None:
            raise ValueError(
                "word_tokens must be provided when the layer was built with "
                "word_tokens=True")

        x = src
        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
        if self.word_tokens:
            x = x + self._csa_block(self.norm3(x), word_tokens)
        x = x + self.dropout2(self.ffn(self.norm2(x)))
        return x

    # encoder block
    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Self-attention over ``x`` followed by dropout."""
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _csa_block(self, x, mem, attn_mask=None, key_padding_mask=None):
        """Cross-attention with ``x`` as query and ``mem`` as key/value, then dropout."""
        x = self.cross_attn(x, mem, mem,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)[0]
        return self.dropout3(x)

    # feed forward block
    def _ff_block(self, x):
        """Feed-forward: linear1 -> activation -> dropout -> linear2."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return x
class Refined_Transformer(nn.Module):
    """Stack of ``num_layers`` deep-copied instances of ``refined_layer``.

    The layers are applied one after another; the same ``word_tokens``,
    ``src_mask`` and ``src_key_padding_mask`` are handed to every layer.
    """

    def __init__(self, refined_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(refined_layer, num_layers)
        self.num_layers = num_layers

    def forward(
            self,
            src,
            word_tokens=None,
            src_mask=None,
            src_key_padding_mask = None):
        """Feed ``src`` through every layer in order and return the final output."""
        hidden = src
        for layer in self.layers:
            hidden = layer(
                hidden,
                word_tokens=word_tokens,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
            )
        return hidden