# CantusSVS-hf / modules/commons/common_layers.py
from __future__ import annotations
import math
import numpy as np
import torch
import torch.nn.functional as F
import torch.onnx.operators
from torch import nn
from torch.nn import LayerNorm, MultiheadAttention, ReLU, GELU, SiLU
import utils
class NormalInitEmbedding(torch.nn.Embedding):
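    """An Embedding layer whose weights are re-initialized from a normal distribution
    with std = embedding_dim ** -0.5, and whose padding row (if any) is zeroed.
    """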
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int | None = None,
*args,
**kwargs
):
super().__init__(num_embeddings, embedding_dim, *args, padding_idx=padding_idx, **kwargs)
nn.init.normal_(self.weight, mean=0, std=self.embedding_dim ** -0.5)
if padding_idx is not None:
nn.init.constant_(self.weight[padding_idx], 0)
class XavierUniformInitLinear(torch.nn.Linear):
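    """A Linear layer whose weight is re-initialized with Xavier (Glorot) uniform
    initialization and whose bias (if enabled) is zeroed.
    """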
def __init__(
self,
in_features: int,
out_features: int,
*args,
bias: bool = True,
**kwargs
):
super().__init__(in_features, out_features, *args, bias=bias, **kwargs)
nn.init.xavier_uniform_(self.weight)
if bias:
nn.init.constant_(self.bias, 0.)
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
"""
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size,
embedding_dim,
padding_idx,
)
self.register_buffer('_float_tensor', torch.FloatTensor(1))
@staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, x, incremental_state=None, timestep=None, positions=None):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = x.shape[:2]
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embeddings if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
self.padding_idx,
)
self.weights = self.weights.to(self._float_tensor)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
positions = utils.make_positions(x, self.padding_idx) if positions is None else positions
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
@staticmethod
def max_positions():
"""Maximum number of supported positions."""
return int(1e5) # an arbitrary large number
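
# Minimal usage sketch for SinusoidalPositionalEmbedding (shapes and sizes are
# illustrative; relies on utils.make_positions from this repository):
#   pos_emb = SinusoidalPositionalEmbedding(embedding_dim=256, padding_idx=0)
#   tokens = torch.randint(1, 100, (2, 50))       # (B, T) token ids, 0 = padding
#   pe = pos_emb(tokens)                          # (B, T, 256), detached
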
class SwiGLU(nn.Module):
    """SwiGLU activation: a gated linear unit whose gate uses the SiLU (Swish) function."""
def __init__(self, dim=-1):
super().__init__()
self.dim = dim
def forward(self, x):
# out, gate = x.chunk(2, dim=self.dim)
# Using torch.split instead of chunk for ONNX export compatibility.
out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
return out * F.silu(gate)
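
# Minimal SwiGLU sketch (illustrative shapes): the last dimension is split in half,
# so a (B, T, 2 * d) input yields a (B, T, d) output.
#   swiglu = SwiGLU(dim=-1)
#   y = swiglu(torch.randn(2, 50, 512))           # -> (2, 50, 256)
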
class TransformerFFNLayer(nn.Module):
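    """Position-wise feed-forward block: a Conv1d over time (kernel_size >= 1), an
    activation ('relu', 'gelu', 'swish' or 'swiglu'), dropout, and a Linear projection
    back to hidden_size. For 'swiglu' the Conv1d output width is doubled, then halved
    again by the gate. Operates on (B, T, C) tensors.
    """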
def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gelu'):
super().__init__()
self.kernel_size = kernel_size
self.dropout = dropout
self.act = act
filter_size_1 = filter_size
if self.act == 'relu':
self.act_fn = ReLU()
elif self.act == 'gelu':
self.act_fn = GELU()
elif self.act == 'swish':
self.act_fn = SiLU()
elif self.act == 'swiglu':
self.act_fn = SwiGLU()
filter_size_1 = filter_size * 2
else:
raise ValueError(f'{act} is not a valid activation')
self.ffn_1 = nn.Conv1d(hidden_size, filter_size_1, kernel_size, padding=kernel_size // 2)
self.ffn_2 = XavierUniformInitLinear(filter_size, hidden_size)
def forward(self, x):
# x: B x T x C
x = self.ffn_1(x.transpose(1, 2)).transpose(1, 2)
x = x * self.kernel_size ** -0.5
x = self.act_fn(x)
x = F.dropout(x, self.dropout, training=self.training)
x = self.ffn_2(x)
return x
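
# Minimal TransformerFFNLayer sketch (illustrative sizes):
#   ffn = TransformerFFNLayer(hidden_size=256, filter_size=1024, kernel_size=9, act='gelu')
#   y = ffn(torch.randn(2, 50, 256))              # (B, T, C) in, (B, T, C) out
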
class MultiheadSelfAttentionWithRoPE(nn.Module):
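    """Multi-head self-attention that optionally applies rotary position embeddings (RoPE)
    to the queries and keys before the scaled dot-product. Expects x of shape (B, L, C) and
    an optional key_padding_mask of shape (B, L) where padding positions are 1/True.
    """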
def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=None):
super().__init__()
assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Linear layers for Q, K, V projections
self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
# Final linear layer after concatenation
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
# Dropout layer
self.dropout = nn.Dropout(dropout)
# Rotary Embeddings
self.rotary_embed = rotary_embed
def forward(self, x, key_padding_mask=None):
# x: (B, L, C)
# key_padding_mask: (B, L)
batch_size, seq_len, embed_dim = x.size()
# Project inputs to Q, K, V
Q, K, V = torch.split(self.in_proj(x), self.embed_dim, dim=-1)
# Reshape Q, K, V for multi-head attention
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, L, D)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, L, D)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, L, D)
# Apply RoPE
if self.rotary_embed is not None:
Q = self.rotary_embed.rotate_queries_or_keys(Q)
K = self.rotary_embed.rotate_queries_or_keys(K)
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim) # (B, H, L, L)
# Apply key padding mask if provided
if key_padding_mask is not None:
# Expand mask to match attention scores shape
mask = key_padding_mask.unsqueeze(1).unsqueeze(1) # (B, 1, 1, L)
scores = scores.masked_fill(mask == 1, -np.inf) # Masked positions are set to -inf
# Compute attention weights
attn_weights = F.softmax(scores, dim=-1) # (B, H, L, L)
attn_weights = self.dropout(attn_weights)
# Apply attention weights to V
attn_output = torch.matmul(attn_weights, V) # (B, H, L, D)
# Reshape and concatenate heads
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim) # (B, L, C)
# Final linear projection
output = self.out_proj(attn_output) # (B, L, C)
return output
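
# Minimal usage sketch. The RotaryEmbedding import is an assumption: any object exposing
# rotate_queries_or_keys on (B, H, L, D) tensors works, e.g. the rotary-embedding-torch
# package; sizes are illustrative.
#   from rotary_embedding_torch import RotaryEmbedding
#   rope = RotaryEmbedding(dim=256 // 4)
#   attn = MultiheadSelfAttentionWithRoPE(embed_dim=256, num_heads=4, rotary_embed=rope)
#   pad_mask = torch.zeros(2, 50, dtype=torch.bool)    # 1/True marks padding
#   y = attn(torch.randn(2, 50, 256), key_padding_mask=pad_mask)   # (B, L, C)
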
class EncSALayer(nn.Module):
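    """Pre-norm Transformer encoder layer: LayerNorm -> self-attention (standard
    nn.MultiheadAttention, or MultiheadSelfAttentionWithRoPE when rotary_embed is given)
    -> residual, then LayerNorm -> TransformerFFNLayer -> residual. Padding positions,
    given by encoder_padding_mask, are zeroed out after each residual connection.
    """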
def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None):
super().__init__()
self.dropout = dropout
self.layer_norm1 = LayerNorm(c)
if rotary_embed is None:
self.self_attn = MultiheadAttention(
c, num_heads, dropout=attention_dropout, bias=False, batch_first=False
)
self.use_rope = False
else:
self.self_attn = MultiheadSelfAttentionWithRoPE(
c, num_heads, dropout=attention_dropout, bias=False, rotary_embed=rotary_embed
)
self.use_rope = True
self.layer_norm2 = LayerNorm(c)
self.ffn = TransformerFFNLayer(
c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, act=act
)
def forward(self, x, encoder_padding_mask=None, **kwargs):
layer_norm_training = kwargs.get('layer_norm_training', None)
if layer_norm_training is not None:
self.layer_norm1.training = layer_norm_training
self.layer_norm2.training = layer_norm_training
residual = x
x = self.layer_norm1(x)
if self.use_rope:
x = self.self_attn(x, key_padding_mask=encoder_padding_mask)
else:
x = x.transpose(0, 1)
            x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask
)
x = x.transpose(0, 1)
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
x = x * (1 - encoder_padding_mask.float())[..., None]
residual = x
x = self.layer_norm2(x)
x = self.ffn(x)
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
x = x * (1 - encoder_padding_mask.float())[..., None]
return x
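
# Minimal EncSALayer sketch (illustrative sizes; forward() indexes encoder_padding_mask
# unconditionally, so a mask must always be passed):
#   layer = EncSALayer(c=256, num_heads=4, dropout=0.1)
#   pad_mask = torch.zeros(2, 50, dtype=torch.bool)
#   y = layer(torch.randn(2, 50, 256), encoder_padding_mask=pad_mask)   # (B, T, C)
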
class SinusoidalPosEmb(nn.Module):
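    """Sinusoidal embedding for a 1-D tensor of scalar positions (e.g. diffusion
    timesteps): maps a (B,) tensor to (B, dim), first half sin, second half cos.
    """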
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
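
# Minimal SinusoidalPosEmb sketch (illustrative values):
#   t_emb = SinusoidalPosEmb(dim=256)
#   t = torch.randint(0, 1000, (8,))              # (B,) step indices
#   e = t_emb(t)                                   # (B, 256)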