import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin

from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)


class TaikoConformer7(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: pools frequency by 4x, keeps full time resolution
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        feat_dim = CNN_CH * (N_MELS // 4)

        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)

        # 3) FiLM conditioning for notes_per_second, difficulty, and level
        self.film_nps = nn.Linear(1, 2 * D_MODEL)
        self.film_difficulty = nn.Linear(1, 2 * D_MODEL)  # difficulty is a single scalar
        self.film_level = nn.Linear(1, 2 * D_MODEL)  # level is a single scalar

        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )

        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # output in [0, 1]
        )

        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        mel: torch.Tensor,
        lengths: torch.Tensor,
        notes_per_second: torch.Tensor,
        difficulty: torch.Tensor,
        level: torch.Tensor,
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel) log-mel spectrogram
            lengths: (B,) valid frame counts; the CNN keeps the time axis,
                so these equal the mel frame counts
            notes_per_second: (B,) per-sample conditioning values
            difficulty: (B,) difficulty values
            level: (B,) level values
        Returns:
            Dict with:
                'presence': (B, T, 3) Don, Ka, DrumRoll energies
                'lengths': (B,) the input lengths, unchanged
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)

        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning: x -> gamma * x + beta for each condition
        nps = notes_per_second.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_nps = self.film_nps(nps)  # (B, 2*D_MODEL)
        gamma_nps, beta_nps = gamma_beta_nps.chunk(2, dim=-1)
        x = gamma_nps.unsqueeze(1) * x + beta_nps.unsqueeze(1)

        diff = difficulty.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_diff = self.film_difficulty(diff)  # (B, 2*D_MODEL)
        gamma_diff, beta_diff = gamma_beta_diff.chunk(2, dim=-1)
        x = gamma_diff.unsqueeze(1) * x + beta_diff.unsqueeze(1)

        lvl = level.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_lvl = self.film_level(lvl)  # (B, 2*D_MODEL)
        gamma_lvl, beta_lvl = gamma_beta_lvl.chunk(2, dim=-1)
        x = gamma_lvl.unsqueeze(1) * x + beta_lvl.unsqueeze(1)

        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)

        # Presence prediction
        presence = self.presence_regressor(x)

        return {"presence": presence, "lengths": lengths}
if __name__ == "__main__":
    model = TaikoConformer7().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)
    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)
    notes_per_second_input = torch.tensor(
        [10.0] * batch_size, dtype=torch.float32
    ).to(DEVICE)
    difficulty_input = torch.tensor([1.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # example difficulty
    level_input = torch.tensor([5.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # example level

    output = model(
        input_mel,
        conformer_lengths,
        notes_per_second_input,
        difficulty_input,
        level_input,
    )

    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")
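    # A minimal, hedged sketch (not part of the original script): because the
    # class inherits PyTorchModelHubMixin, an instance can be saved and
    # reloaded with the mixin's save_pretrained/from_pretrained helpers.
    # The directory name "taiko_conformer7_ckpt" is illustrative only.
    model.save_pretrained("taiko_conformer7_ckpt")
    reloaded = TaikoConformer7.from_pretrained("taiko_conformer7_ckpt").to(DEVICE)
    reloaded.eval()
    with torch.no_grad():
        reloaded_out = reloaded(
            input_mel,
            conformer_lengths,
            notes_per_second_input,
            difficulty_input,
            level_input,
        )
    print("Reloaded presence shape:", reloaded_out["presence"].shape)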