import torch
from torch import nn
import torch.nn.functional as F
from motion.model.layer_norm_fp16 import RMSNorm, LayerNorm
class ResConv1DBlock(nn.Module):
    """Residual block: 3-tap conv -> norm -> activation -> 1x1 conv -> norm -> activation, plus a skip connection."""

    def __init__(self, n_in, n_state, bias, norm_type, activate_type):
        super().__init__()
        if activate_type.lower() == "silu":
            activate = nn.SiLU()
        elif activate_type.lower() == "relu":
            activate = nn.ReLU()
        elif activate_type.lower() == "gelu":
            activate = nn.GELU()
        elif activate_type.lower() == "mish":
            activate = nn.Mish()
        else:
            raise ValueError(f"Unsupported activate_type: {activate_type}")
        if norm_type.lower() == "rmsnorm":
            norm = RMSNorm
        elif norm_type.lower() == "layernorm":
            norm = LayerNorm
        else:
            raise ValueError(f"Unsupported norm_type: {norm_type}")
        self.norm1 = norm(n_state)
        self.norm2 = norm(n_in)
        # The activation modules are stateless, so sharing one instance is safe.
        self.relu1 = activate
        self.relu2 = activate
        self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0, bias=bias)

    def forward(self, x):
        # x: [bs, n_in, T]
        x_orig = x
        x = self.conv1(x)
        # The norms normalize over the last dim, so move channels last, then back.
        x = self.norm1(x.transpose(-2, -1))
        x = self.relu1(x.transpose(-2, -1))
        x = self.conv2(x)
        x = self.norm2(x.transpose(-2, -1))
        x = self.relu2(x.transpose(-2, -1))
        return x + x_orig
class Encoder_Block(nn.Module):
    """Downsampling encoder: a stack of stride-2 convs with residual blocks, then adaptive max pooling to TN tokens."""

    def __init__(self, begin_channel=263, latent_dim=512, num_layers=6, TN=1,
                 bias=False, norm_type="rmsnorm", activate_type="silu"):
        super(Encoder_Block, self).__init__()
        target_channel = latent_dim
        if activate_type.lower() == "silu":
            activate = nn.SiLU()
        elif activate_type.lower() == "relu":
            activate = nn.ReLU()
        elif activate_type.lower() == "gelu":
            activate = nn.GELU()
        elif activate_type.lower() == "mish":
            activate = nn.Mish()
        else:
            raise ValueError(f"Unsupported activate_type: {activate_type}")
        layers = []
        # Stem: project the input features to latent_dim and halve the temporal length.
        layers.append(nn.Conv1d(begin_channel, target_channel, 3, 2, 1, bias=bias))
        layers.append(activate)
        # Each stride-2 conv (k=3, p=1) maps L -> ceil(L / 2); with the stem plus
        # six layers, 196 frames shrink as 196 -> 98 -> 49 -> 25 -> 13 -> 7 -> 4 -> 2.
        for _ in range(num_layers):
            layers.append(nn.Conv1d(target_channel, target_channel, 3, 2, 1, bias=bias))
            layers.append(activate)
            layers.append(ResConv1DBlock(target_channel, target_channel, bias, norm_type, activate_type))
        self.layers = nn.Sequential(*layers)
        self.maxpool = nn.AdaptiveMaxPool1d(TN)

    def forward(self, x):
        bs, njoints, nfeats, nframes = x.shape
        reshaped_x = x.reshape(bs, njoints * nfeats, nframes)  # [bs, njoints * nfeats, nframes]
        res1 = self.layers(reshaped_x)  # [bs, latent_dim, L'] after the stride-2 stack
        res2 = self.maxpool(res1)       # [bs, latent_dim, TN]
        res3 = res2.permute(2, 0, 1)    # [TN, bs, latent_dim]
        return res3
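

# A minimal smoke test (an illustrative sketch, not part of the original module):
# it feeds a dummy batch shaped like 196-frame, 263-dim motion features through
# the encoder and prints the pooled output shape. The batch size and frame count
# here are assumptions for demonstration only.
if __name__ == "__main__":
    encoder = Encoder_Block(begin_channel=263, latent_dim=512, num_layers=6, TN=1)
    dummy = torch.randn(2, 263, 1, 196)  # [bs, njoints, nfeats, nframes]
    out = encoder(dummy)
    print(out.shape)  # expected: torch.Size([1, 2, 512]), i.e. [TN, bs, latent_dim]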