import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin

from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)
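
# For reference, plausible values for the constants imported from .config
# (these are assumptions for illustration; the actual config.py is not shown):
#
#   N_MELS = 80                       # mel bins in the input spectrogram
#   CNN_CH = 64                       # channels in the CNN frontend
#   D_MODEL = 256                     # Conformer model dimension
#   N_HEADS = 4                       # attention heads
#   FF_DIM = 1024                     # feed-forward dimension
#   N_LAYERS = 8                      # Conformer blocks
#   DROPOUT = 0.1
#   DEPTHWISE_CONV_KERNEL_SIZE = 31   # standard Conformer kernel size
#   HIDDEN_DIM = 128                  # presence-head hidden width
#   DEVICE = "cuda" if torch.cuda.is_available() else "cpu"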


class TaikoConformer7(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        feat_dim = CNN_CH * (N_MELS // 4)
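        # Each stride-(2, 1) conv halves the mel axis and leaves the time axis
        # untouched, so after two convs F = N_MELS // 4 while T is preserved.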
        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)
        # 3) FiLM conditioning for notes_per_second, difficulty, and level
        self.film_nps = nn.Linear(1, 2 * D_MODEL)
        self.film_difficulty = nn.Linear(1, 2 * D_MODEL)  # assuming difficulty is a single scalar
        self.film_level = nn.Linear(1, 2 * D_MODEL)  # assuming level is a single scalar
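        # FiLM (feature-wise linear modulation): each scalar condition c is mapped
        # to a per-channel scale and shift, and the features are transformed as
        #   x <- gamma(c) * x + beta(c)
        # broadcast over the time axis (see the forward pass below).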
        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )
        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # output in [0, 1]
        )
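        # The sigmoid bounds each per-frame energy to [0, 1]; a regression loss
        # such as MSE or BCE over these targets is a natural fit (assumption:
        # the training code is not shown here).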
        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        mel: torch.Tensor,
        lengths: torch.Tensor,
        notes_per_second: torch.Tensor,
        difficulty: torch.Tensor,
        level: torch.Tensor,
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) valid frame counts; the CNN does not downsample the
                time axis, so these equal the mel-frame lengths
            notes_per_second: (B,) note-density conditioning values
            difficulty: (B,) difficulty conditioning values
            level: (B,) level conditioning values
        Returns:
            Dict with:
                'presence': (B, T_mel, 3) per-frame Don, Ka, DrumRoll energies
                'lengths': lengths, passed through unchanged
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)
        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)
        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_nps = self.film_nps(nps)  # (B, 2*D_MODEL)
        gamma_nps, beta_nps = gamma_beta_nps.chunk(2, dim=-1)
        x = gamma_nps.unsqueeze(1) * x + beta_nps.unsqueeze(1)
        diff = difficulty.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_diff = self.film_difficulty(diff)  # (B, 2*D_MODEL)
        gamma_diff, beta_diff = gamma_beta_diff.chunk(2, dim=-1)
        x = gamma_diff.unsqueeze(1) * x + beta_diff.unsqueeze(1)
        lvl = level.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_lvl = self.film_level(lvl)  # (B, 2*D_MODEL)
        gamma_lvl, beta_lvl = gamma_beta_lvl.chunk(2, dim=-1)
        x = gamma_lvl.unsqueeze(1) * x + beta_lvl.unsqueeze(1)
        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)
        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}

if __name__ == "__main__":
    model = TaikoConformer7().to(device=DEVICE)
    print(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)
    conformer_lengths = torch.tensor([mel_time_steps] * batch_size, dtype=torch.long).to(DEVICE)
    notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE)  # example difficulty
    level_input = torch.tensor([5.0] * batch_size, dtype=torch.float32).to(DEVICE)  # example level

    output = model(
        input_mel,
        conformer_lengths,
        notes_per_second_input,
        difficulty_input,
        level_input,
    )
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")