from __future__ import annotations

import math

import numpy as np
import torch
import torch.nn.functional as F
import torch.onnx.operators
from torch import nn
from torch.nn import LayerNorm, MultiheadAttention, ReLU, GELU, SiLU

import utils


class NormalInitEmbedding(torch.nn.Embedding):
    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            padding_idx: int | None = None,
            *args,
            **kwargs
    ):
        super().__init__(num_embeddings, embedding_dim, *args, padding_idx=padding_idx, **kwargs)
        nn.init.normal_(self.weight, mean=0, std=self.embedding_dim ** -0.5)
        if padding_idx is not None:
            nn.init.constant_(self.weight[padding_idx], 0)


class XavierUniformInitLinear(torch.nn.Linear):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            *args,
            bias: bool = True,
            **kwargs
    ):
        super().__init__(in_features, out_features, *args, bias=bias, **kwargs)
        nn.init.xavier_uniform_(self.weight)
        if bias:
            nn.init.constant_(self.bias, 0.)


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, x, incremental_state=None, timestep=None, positions=None):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = x.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = utils.make_positions(x, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    @staticmethod
    def max_positions():
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number


class SwiGLU(nn.Module):
    """Swish-gated linear unit: splits the input along `dim` and gates one half with SiLU of the other."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # out, gate = x.chunk(2, dim=self.dim)
        # Using torch.split instead of chunk for ONNX export compatibility.
        out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
        return out * F.silu(gate)
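
# --- Hedged usage sketch (illustration only; the helper name below is hypothetical,
# not part of the original API). It shows the expected shapes for the blocks above,
# assuming this module imports cleanly (i.e. the project-local `utils` module is
# importable). Passing `positions` explicitly sidesteps `utils.make_positions`.
def _sanity_check_embeddings_and_swiglu():
    # SwiGLU halves the last dimension: (B, T, 2C) -> (B, T, C).
    gated = SwiGLU(dim=-1)(torch.randn(2, 5, 8))
    assert gated.shape == (2, 5, 4)

    # Sinusoidal positional embeddings for a padded batch of token ids (padding_idx = 0).
    pos_emb = SinusoidalPositionalEmbedding(embedding_dim=16, padding_idx=0)
    tokens = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
    # Fairseq-style positions: real tokens count up from padding_idx + 1, pads stay at padding_idx.
    positions = torch.tensor([[1, 2, 3, 0], [1, 2, 0, 0]])
    assert pos_emb(tokens, positions=positions).shape == (2, 4, 16)
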
class TransformerFFNLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gelu'):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.act = act
        filter_size_1 = filter_size
        if self.act == 'relu':
            self.act_fn = ReLU()
        elif self.act == 'gelu':
            self.act_fn = GELU()
        elif self.act == 'swish':
            self.act_fn = SiLU()
        elif self.act == 'swiglu':
            self.act_fn = SwiGLU()
            # SwiGLU halves its input, so the first projection must be twice as wide.
            filter_size_1 = filter_size * 2
        else:
            raise ValueError(f'{act} is not a valid activation')
        self.ffn_1 = nn.Conv1d(hidden_size, filter_size_1, kernel_size, padding=kernel_size // 2)
        self.ffn_2 = XavierUniformInitLinear(filter_size, hidden_size)

    def forward(self, x):
        # x: B x T x C
        x = self.ffn_1(x.transpose(1, 2)).transpose(1, 2)
        x = x * self.kernel_size ** -0.5
        x = self.act_fn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.ffn_2(x)
        return x


class MultiheadSelfAttentionWithRoPE(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=None):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear layers for Q, K, V projections
        self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
        # Final linear layer after concatenation
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Dropout layer
        self.dropout = nn.Dropout(dropout)
        # Rotary embeddings
        self.rotary_embed = rotary_embed

    def forward(self, x, key_padding_mask=None):
        # x: (B, L, C)
        # key_padding_mask: (B, L)
        batch_size, seq_len, embed_dim = x.size()

        # Project inputs to Q, K, V
        Q, K, V = torch.split(self.in_proj(x), self.embed_dim, dim=-1)

        # Reshape Q, K, V for multi-head attention
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, L, D)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, L, D)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, L, D)

        # Apply RoPE
        if self.rotary_embed is not None:
            Q = self.rotary_embed.rotate_queries_or_keys(Q)
            K = self.rotary_embed.rotate_queries_or_keys(K)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)  # (B, H, L, L)

        # Apply key padding mask if provided
        if key_padding_mask is not None:
            # Expand mask to match attention scores shape
            mask = key_padding_mask.unsqueeze(1).unsqueeze(1)  # (B, 1, 1, L)
            scores = scores.masked_fill(mask == 1, -np.inf)  # Masked positions are set to -inf

        # Compute attention weights
        attn_weights = F.softmax(scores, dim=-1)  # (B, H, L, L)
        attn_weights = self.dropout(attn_weights)

        # Apply attention weights to V
        attn_output = torch.matmul(attn_weights, V)  # (B, H, L, D)

        # Reshape and concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)  # (B, L, C)

        # Final linear projection
        output = self.out_proj(attn_output)  # (B, L, C)
        return output
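
# --- Hedged usage sketch (illustration only; the helper name below is hypothetical,
# not part of the original API). It traces shapes through the position-wise FFN and
# the RoPE-ready attention. `rotary_embed` is left at None here; in the real model it
# is expected to expose `rotate_queries_or_keys`, e.g. a RotaryEmbedding from the
# `rotary-embedding-torch` package (an assumption, not confirmed by this file).
def _sanity_check_ffn_and_attention():
    x = torch.randn(2, 7, 32)  # (B, L, C)

    # With act='swiglu', the Conv1d outputs 2 * filter_size so SwiGLU can split it back.
    ffn = TransformerFFNLayer(hidden_size=32, filter_size=128, kernel_size=9, act='swiglu')
    assert ffn(x).shape == (2, 7, 32)

    # Plain scaled dot-product self-attention path (no rotary embedding supplied).
    attn = MultiheadSelfAttentionWithRoPE(embed_dim=32, num_heads=4, rotary_embed=None)
    pad_mask = torch.zeros(2, 7, dtype=torch.bool)  # True marks padded positions
    assert attn(x, key_padding_mask=pad_mask).shape == (2, 7, 32)
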
class EncSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
                 kernel_size=9, act='gelu', rotary_embed=None):
        super().__init__()
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        if rotary_embed is None:
            self.self_attn = MultiheadAttention(
                c, num_heads, dropout=attention_dropout, bias=False, batch_first=False
            )
            self.use_rope = False
        else:
            self.self_attn = MultiheadSelfAttentionWithRoPE(
                c, num_heads, dropout=attention_dropout, bias=False, rotary_embed=rotary_embed
            )
            self.use_rope = True
        self.layer_norm2 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, act=act
        )

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training

        # Self-attention block (pre-norm, residual).
        residual = x
        x = self.layer_norm1(x)
        if self.use_rope:
            x = self.self_attn(x, key_padding_mask=encoder_padding_mask)
        else:
            x = x.transpose(0, 1)
            x, _ = self.self_attn(
                query=x, key=x, value=x,
                key_padding_mask=encoder_padding_mask
            )
            x = x.transpose(0, 1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        x = x * (1 - encoder_padding_mask.float())[..., None]

        # Position-wise FFN block (pre-norm, residual).
        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        x = x * (1 - encoder_padding_mask.float())[..., None]
        return x


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
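
# --- Hedged usage sketch (illustration only, not part of the original module) ---
# A quick shape check of an encoder self-attention layer and the sinusoidal step
# embedding. Running this requires the project-local `utils` module to be importable,
# since it is imported at the top of this file.
if __name__ == '__main__':
    layer = EncSALayer(c=32, num_heads=4, dropout=0.1, rotary_embed=None).eval()
    h = torch.randn(2, 7, 32)                            # (B, L, C)
    padding_mask = torch.zeros(2, 7, dtype=torch.bool)   # True marks padded frames
    # Note: a padding mask is required here; forward() uses it unconditionally.
    print(layer(h, encoder_padding_mask=padding_mask).shape)  # torch.Size([2, 7, 32])

    # Sinusoidal embedding of integer steps (named like a diffusion-step embedding,
    # which is an assumption based on the class name only).
    steps = torch.arange(4)
    print(SinusoidalPosEmb(dim=64)(steps).shape)  # torch.Size([4, 64])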