import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin

from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)
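# The constants above are hyperparameters from the sibling config module.
# Their roles, as used below: N_MELS is the number of mel bins in the input
# spectrogram, CNN_CH the channel width of the CNN frontend, D_MODEL /
# N_HEADS / FF_DIM / N_LAYERS / DEPTHWISE_CONV_KERNEL_SIZE the Conformer
# encoder dimensions, HIDDEN_DIM the width of the output head, DROPOUT the
# shared dropout rate, and DEVICE the target torch device.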


class TaikoConformer7(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
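        # Each stride-(2, 1) conv halves the frequency axis and leaves the
        # time axis untouched, so the two of them reduce N_MELS to
        # N_MELS // 4 while every output frame still lines up 1:1 with a
        # mel frame.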
        feat_dim = CNN_CH * (N_MELS // 4)
        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)
        # 3) FiLM conditioning for notes_per_second, difficulty, and level
        # (each conditioning signal is assumed to be a single scalar per example)
        self.film_nps = nn.Linear(1, 2 * D_MODEL)
        self.film_difficulty = nn.Linear(1, 2 * D_MODEL)
        self.film_level = nn.Linear(1, 2 * D_MODEL)
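        # FiLM (feature-wise linear modulation): each layer maps its scalar
        # to a (gamma, beta) pair of size D_MODEL, and the encoder input is
        # modulated as x = gamma * x + beta, broadcast across time (see
        # forward below).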
        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )
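        # torchaudio's Conformer consumes batch-first (B, T, input_dim)
        # features plus per-example valid lengths, and returns an
        # (output, lengths) tuple.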
        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # outputs in [0, 1]
        )
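        # The head emits three per-frame energies (Don, Ka, DrumRoll)
        # squashed to [0, 1] by the sigmoid; suitable targets would be soft
        # presence labels regressed with, e.g., BCE or MSE (an assumption --
        # the training objective lives outside this file).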
        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
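        # Note: this loop also re-initializes the FiLM layers, so the
        # initial gammas are small values near zero and strongly attenuate
        # x early in training. A common alternative (not done here) is an
        # identity init with gamma = 1, beta = 0.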

    def forward(
        self,
        mel: torch.Tensor,
        lengths: torch.Tensor,
        notes_per_second: torch.Tensor,
        difficulty: torch.Tensor,
        level: torch.Tensor,
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) valid frame counts per example; the CNN pools
                only frequency, so these equal the mel-frame lengths
            notes_per_second: (B,) note-density conditioning values
            difficulty: (B,) difficulty conditioning values
            level: (B,) level conditioning values
        Returns:
            Dict with:
                'presence': (B, T, 3) per-frame Don, Ka, DrumRoll energies
                'lengths': lengths, passed through unchanged
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)  # (B, T, feat_dim)
        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_nps = self.film_nps(nps)  # (B, 2 * D_MODEL)
        gamma_nps, beta_nps = gamma_beta_nps.chunk(2, dim=-1)
        x = gamma_nps.unsqueeze(1) * x + beta_nps.unsqueeze(1)

        diff = difficulty.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_diff = self.film_difficulty(diff)  # (B, 2 * D_MODEL)
        gamma_diff, beta_diff = gamma_beta_diff.chunk(2, dim=-1)
        x = gamma_diff.unsqueeze(1) * x + beta_diff.unsqueeze(1)

        lvl = level.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_lvl = self.film_level(lvl)  # (B, 2 * D_MODEL)
        gamma_lvl, beta_lvl = gamma_beta_lvl.chunk(2, dim=-1)
        x = gamma_lvl.unsqueeze(1) * x + beta_lvl.unsqueeze(1)

        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)
        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}
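
    # Note: the encoder masks padded frames internally via `lengths`, but
    # the presence head still produces outputs for padded positions;
    # downstream losses are expected to mask with `lengths` (an assumption
    # about usage, not enforced here).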


if __name__ == "__main__":
    model = TaikoConformer7().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    # Smoke test with a random batch
    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)
    # Time resolution is preserved by the CNN, so frame lengths equal mel lengths
    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)
    notes_per_second_input = torch.tensor(
        [10.0] * batch_size, dtype=torch.float32
    ).to(DEVICE)
    difficulty_input = torch.tensor(
        [1.0] * batch_size, dtype=torch.float32
    ).to(DEVICE)  # Example difficulty
    level_input = torch.tensor(
        [5.0] * batch_size, dtype=torch.float32
    ).to(DEVICE)  # Example level

    output = model(
        input_mel,
        conformer_lengths,
        notes_per_second_input,
        difficulty_input,
        level_input,
    )
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")