import torch
from torch import nn
from torch.nn import functional as F


class ChannelLastConv1d(nn.Conv1d):
    """nn.Conv1d that accepts channel-last input of shape (B, T, C)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Conv1d expects (B, C, T), so move channels to dim 1, convolve,
        # then restore the channel-last layout.
        x = x.permute(0, 2, 1)
        x = super().forward(x)
        x = x.permute(0, 2, 1)
        return x


class MLP(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        """
        Initialize the feed-forward (SwiGLU) module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feed-forward layer.
            multiple_of (int): Round the hidden dimension up to a multiple of this value.

        Attributes:
            w1 (nn.Linear): Linear transformation for the gate branch.
            w2 (nn.Linear): Linear transformation back to the input dimension.
            w3 (nn.Linear): Linear transformation for the value branch.
        """
        super().__init__()
        # SwiGLU sizing: scale the hidden width by 2/3 so the gated MLP has
        # roughly the parameter count of an ungated one, then round up to a
        # multiple of `multiple_of` (e.g. 1792 -> 1194 -> 1280 with multiple_of=256).
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU: w2(silu(w1(x)) * w3(x))
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class ConvMLP(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        """
        Initialize the convolutional feed-forward (SwiGLU) module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feed-forward layer.
            multiple_of (int): Round the hidden dimension up to a multiple of this value.
            kernel_size (int): Kernel size of each convolution.
            padding (int): Padding of each convolution.

        Attributes:
            w1 (ChannelLastConv1d): Convolution for the gate branch.
            w2 (ChannelLastConv1d): Convolution back to the input dimension.
            w3 (ChannelLastConv1d): Convolution for the value branch.
        """
        super().__init__()
        # Same SwiGLU hidden-dimension sizing as MLP above.
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ChannelLastConv1d(dim,
                                    hidden_dim,
                                    bias=False,
                                    kernel_size=kernel_size,
                                    padding=padding)
        self.w2 = ChannelLastConv1d(hidden_dim,
                                    dim,
                                    bias=False,
                                    kernel_size=kernel_size,
                                    padding=padding)
        self.w3 = ChannelLastConv1d(dim,
                                    hidden_dim,
                                    bias=False,
                                    kernel_size=kernel_size,
                                    padding=padding)

    def forward(self, x):
        # SwiGLU with convolutions: w2(silu(w1(x)) * w3(x))
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


if __name__ == "__main__":
    latent_dim = 20
    hidden_dim = 64 * 7

    conv1d = ChannelLastConv1d(
        in_channels=latent_dim,
        out_channels=hidden_dim,
        kernel_size=7,
        padding=3,
    )
    num_params = sum(p.numel() for p in conv1d.parameters())
    print(conv1d)
    print(f"Num params for conv1d: {num_params}")

    B, T, D = 128, 250, 20
    x = torch.randn(B, T, D)
    h = conv1d(x)

    conv_mlp = ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3)
    num_params = sum(p.numel() for p in conv_mlp.parameters())
    print(conv_mlp)
    print(f"Num params for conv_mlp: {num_params}")
    y = conv_mlp(h)
    print(y.shape)
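    # Quick sanity check on shapes: with the default stride of 1, kernel_size=7
    # together with padding=3 preserves the sequence length, so both outputs
    # should stay channel-last with hidden_dim channels.
    assert h.shape == (B, T, hidden_dim)
    assert y.shape == (B, T, hidden_dim)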