# tc5-exp/tc7/model.py
# Commit 812b01c (JacobLinCool): "Implement TaikoConformer7 model, loss function,
# preprocessing, and training pipeline"
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin
from .config import (
N_MELS,
CNN_CH,
N_HEADS,
D_MODEL,
FF_DIM,
N_LAYERS,
DROPOUT,
DEPTHWISE_CONV_KERNEL_SIZE,
HIDDEN_DIM,
DEVICE,
)
class TaikoConformer7(nn.Module, PyTorchModelHubMixin):
    """Conformer-based per-frame presence regressor with FiLM conditioning.

    Pipeline:
      1. CNN frontend that downsamples only the frequency axis (time preserved).
      2. Linear projection of the flattened (channel x frequency) features to
         ``D_MODEL``.
      3. Three FiLM layers conditioned on notes-per-second, difficulty and level.
      4. torchaudio ``Conformer`` encoder.
      5. MLP head producing 3 sigmoid outputs per frame (Don, Ka, DrumRoll).
    """

    def __init__(self):
        super().__init__()
        # 1) CNN frontend: both convs use stride (2, 1), so the frequency axis
        #    is halved twice while the time axis keeps its length.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        # A kernel-3 / stride-2 / padding-1 conv maps F -> ceil(F / 2); two of
        # them give ceil(N_MELS / 4). ``N_MELS // 4`` would under-count when
        # N_MELS is not a multiple of 4 (and is identical when it is).
        feat_dim = CNN_CH * ((N_MELS + 3) // 4)

        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)

        # 3) FiLM conditioning: each layer maps one scalar per sequence to a
        #    (gamma, beta) pair of size D_MODEL each.
        self.film_nps = nn.Linear(1, 2 * D_MODEL)
        self.film_difficulty = nn.Linear(1, 2 * D_MODEL)  # difficulty: one scalar per sequence
        self.film_level = nn.Linear(1, 2 * D_MODEL)  # level: one scalar per sequence

        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )

        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # Output between 0 and 1
        )

        # 6) Initialize weights.
        # NOTE(review): self.modules() recurses into every submodule, so this
        # presumably also re-initializes nn.Linear layers inside the torchaudio
        # Conformer — confirm that is intended. The Kaiming init uses the ReLU
        # gain although the frontend activations are GELU; it is left unchanged
        # to keep training from scratch reproducible.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _apply_film(
        film: nn.Linear, cond: torch.Tensor, x: torch.Tensor
    ) -> torch.Tensor:
        """Modulate ``x`` as ``gamma * x + beta`` with ``(gamma, beta) = film(cond)``.

        Args:
            film: Linear layer mapping 1 -> 2 * D_MODEL.
            cond: (B,) one conditioning scalar per sequence (any numeric dtype).
            x: (B, T, D_MODEL) features to modulate.

        Returns:
            (B, T, D_MODEL) modulated features.
        """
        # (B,) -> (B, 1), cast to the layer's parameter dtype so integer inputs
        # and half-precision models both work (fp32 models behave as `.float()`).
        c = cond.unsqueeze(-1).to(film.weight.dtype)
        gamma, beta = film(c).chunk(2, dim=-1)  # each (B, D_MODEL)
        # NOTE(review): gamma is used directly rather than as (1 + gamma). With
        # the zero-initialized biases above, a conditioning value of 0 yields
        # gamma = beta = 0 at init and zeroes x. Changing the formulation would
        # invalidate trained checkpoints, so it is kept as-is.
        # Broadcast the per-sequence modulation over the time axis.
        return gamma.unsqueeze(1) * x + beta.unsqueeze(1)

    def forward(
        self,
        mel: torch.Tensor,
        lengths: torch.Tensor,
        notes_per_second: torch.Tensor,
        difficulty: torch.Tensor,
        level: torch.Tensor,
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) valid frame counts after the CNN. The CNN keeps the
                time axis, so these are in T_mel frames; passed to the Conformer
                as its padding lengths.
            notes_per_second: (B,) control values
            difficulty: (B,) difficulty values
            level: (B,) level values
        Returns:
            Dict with:
                'presence': (B, T_cnn_out, 3) sigmoid outputs (Don, Ka, DrumRoll).
                    Frames beyond ``lengths`` are not masked here.
                'lengths': the input ``lengths``, unchanged
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        # Time-major, with channel and frequency flattened per frame.
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)

        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning (applied in this fixed order).
        x = self._apply_film(self.film_nps, notes_per_second, x)
        x = self._apply_film(self.film_difficulty, difficulty, x)
        x = self._apply_film(self.film_level, level, x)

        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)

        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}
if __name__ == "__main__":
    # Smoke test: build the model, report trainable parameter counts, and run
    # one forward pass on random input to check the output shapes.
    model = TaikoConformer7().to(device=DEVICE)
    print(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps, device=DEVICE)
    # The CNN frontend preserves the time axis, so every sequence is full length.
    conformer_lengths = torch.full(
        (batch_size,), mel_time_steps, dtype=torch.long, device=DEVICE
    )
    notes_per_second_input = torch.full(
        (batch_size,), 10.0, dtype=torch.float32, device=DEVICE
    )
    difficulty_input = torch.full(
        (batch_size,), 1.0, dtype=torch.float32, device=DEVICE
    )  # Example difficulty
    level_input = torch.full(
        (batch_size,), 5.0, dtype=torch.float32, device=DEVICE
    )  # Example level

    # Only shapes are inspected: disable dropout and skip building an autograd
    # graph to keep memory use down.
    model.eval()
    with torch.no_grad():
        output = model(
            input_mel,
            conformer_lengths,
            notes_per_second_input,
            difficulty_input,
            level_input,
        )
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")