# tc5-exp / tc5 / preprocess.py
# Author: JacobLinCool
# Commit message: Implement TaikoConformer7 model, loss function, preprocessing, and training pipeline
# Commit: 812b01c
import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer5
# Feature extractor shared by every call to ``preprocess``; built once at import.
# Produces a mel spectrogram with N_MELS bins and one frame every HOP_LENGTH samples.
mel_transform = torchaudio.transforms.MelSpectrogram(
    n_fft=2048,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    sample_rate=SAMPLE_RATE,
)

# SpecAugment-style frequency masking; freq_mask_param=15 is the maximum mask
# width (in mel bins) that torchaudio may draw for each application.
freq_mask = FrequencyMasking(freq_mask_param=15)
def _apply_decay_tail(labels, start, kernel):
    """Max-merge ``kernel`` into ``labels`` beginning at index ``start``, in place.

    Equivalent to ``labels[start + i] = max(labels[start + i], kernel[i])`` for
    every ``i`` whose target index lies inside ``[0, len(labels))``; parts of the
    window that fall outside the label sequence are clipped.
    """
    n = labels.shape[0]
    k_lo = max(0, -start)
    k_hi = min(kernel.shape[0], n - start)
    if k_hi <= k_lo:  # window lies entirely outside the sequence
        return
    lo, hi = start + k_lo, start + k_hi
    labels[lo:hi] = torch.maximum(labels[lo:hi], kernel[k_lo:k_hi])


def preprocess(example, difficulty="oni", *, augment=True, verbose=True):
    """Turn one dataset example into a mel spectrogram plus energy-style labels.

    Args:
        example: Mapping with ``example["audio"]["array"]`` (waveform),
            ``example["audio"]["sampling_rate"]`` and ``example[difficulty]``,
            an iterable of onsets unpacked as ``(typ, t_start, t_end, ...)``
            with times in seconds. The waveform is assumed to be a 1-D (mono)
            torch tensor -- the ``_, _, T = mel.shape`` unpacking below only
            works in that case. TODO confirm against the dataset loader.
        difficulty: Key of the chart to label (default ``"oni"``).
        augment: When True (default), apply random Gaussian noise (p=0.5)
            and frequency masking. Pass False for evaluation data.
        verbose: When True (default), print a per-example note-count summary.

    Returns:
        dict with ``"mel"`` (1, N_MELS, T); ``"don_labels"``, ``"ka_labels"``,
        ``"drumroll_labels"`` (float32, length ``ceil(T / TIME_SUB)``);
        ``"nps"`` and ``"duration_seconds"`` as 0-d float32 tensors.

    NOTE(review): if this runs inside a cached dataset ``map``, the random
    augmentation is drawn once per example rather than once per epoch.
    """
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]

    # 1) resample to the model's sample rate if needed
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)

    # Peak-normalise; the epsilon keeps an all-zero clip from dividing by zero.
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)

    # Augmentation: additive Gaussian noise with probability 0.5.
    if augment and torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)

    # 2) mel: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)
    # SpecAugment frequency masking only: time masking is deliberately not used,
    # since the labels would still mark notes inside a masked time span.
    if augment:
        mel = freq_mask(mel)
    _, _, T = mel.shape

    # 3) label sequences of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    # Exponential decay tail, measured in label (sub-sampled) frames.
    tail_length = 40
    decay_rate = 8.0
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )

    fps = SAMPLE_RATE / HOP_LENGTH  # mel frames per second
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset
        # Types outside 1..N_TYPES are ignored (N_TYPES is expected to be >= 7).
        if typ < 1 or typ > N_TYPES:
            continue
        num_valid_notes += 1
        # float() accepts 0-d tensors as well as plain Python/NumPy numbers.
        idx = int(round(float(t_start) * fps)) // TIME_SUB
        if not 0 <= idx < T_sub:
            continue
        if typ == 1 or typ == 3:  # Don
            _apply_decay_tail(don_labels, idx, tail_kernel)
        elif typ == 2 or typ == 4:  # Ka
            _apply_decay_tail(ka_labels, idx, tail_kernel)
        elif 5 <= typ <= 7:  # Drumroll: hold 1.0 over the roll, then decay
            idx_end = int(round(float(t_end) * fps)) // TIME_SUB
            stop = min(idx_end, T_sub)
            if stop > idx:  # guard also prevents negative-slice wraparound
                drumroll_labels[idx:stop] = 1.0
            _apply_decay_tail(drumroll_labels, idx_end, tail_kernel)

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0
    if verbose:
        print(
            f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}"
        )
    return {
        "mel": mel,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "nps": torch.tensor(nps, dtype=torch.float32),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
    }
# Per-process cache of the model's CNN front-end, used only to infer the time
# length of its output. Constructing a full TaikoConformer5 for every batch is
# expensive, so a single instance is built lazily and reused.
_SHAPE_PROBE_CNN = None


def _shape_probe_cnn():
    """Return a cached ``TaikoConformer5().cnn`` in eval mode (shape inference only)."""
    global _SHAPE_PROBE_CNN
    if _SHAPE_PROBE_CNN is None:
        # eval() avoids BatchNorm's train-mode restrictions on the 1-row probe
        # below; the output shape is assumed not to depend on train/eval mode.
        _SHAPE_PROBE_CNN = TaikoConformer5().cnn.eval()
    return _SHAPE_PROBE_CNN


def _fit_length(labels, out_len):
    """Zero-pad at the end, or truncate, a 1-D label tensor to exactly ``out_len``."""
    n = labels.shape[0]
    if n < out_len:
        return torch.cat([labels, labels.new_zeros(out_len - n)], dim=0)
    return labels[:out_len]


def collate_fn(batch):
    """Collate examples produced by ``preprocess`` into one padded batch.

    Mels are zero-padded in time to the longest item. Every label sequence is
    zero-padded or truncated to the time length of the CNN output for that
    padded batch.

    Returns:
        dict with ``"mel"`` (B, 1, N_MELS, T_max); ``"don_labels"``,
        ``"ka_labels"``, ``"drumroll_labels"`` (B, T_cnn); ``"lengths"``
        (B,) long; ``"nps"`` and ``"durations"`` (B,).
    """
    # Each b["mel"] is (1, N_MELS, T_i); make it (T_i, N_MELS) so pad_sequence pads time.
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]
    padded_mels = nn.utils.rnn.pad_sequence(mels_list, batch_first=True)  # (B, T_max, N_MELS)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)  # (B, 1, N_MELS, T_max)

    # The CNN's output time length depends only on the input shape, so probing
    # with a single row of the padded batch gives the same T_cnn as the full batch.
    with torch.no_grad():
        _, _, _, T_cnn = _shape_probe_cnn()(reshaped_mels[:1]).shape

    padded_don_labels = [_fit_length(b["don_labels"], T_cnn) for b in batch]
    padded_ka_labels = [_fit_length(b["ka_labels"], T_cnn) for b in batch]
    padded_drumroll_labels = [_fit_length(b["drumroll_labels"], T_cnn) for b in batch]

    # Conformer lengths: ceil(T_i / TIME_SUB) per item, capped at T_cnn.
    # NOTE(review): this assumes the CNN downsamples time by TIME_SUB; if it
    # does not, these lengths need to be recomputed -- confirm against model.cnn.
    conformer_input_lengths = torch.tensor(
        [min(math.ceil(m.shape[0] / TIME_SUB), T_cnn) for m in mels_list],
        dtype=torch.long,
    )

    return {
        "mel": reshaped_mels,
        "don_labels": torch.stack(padded_don_labels),
        "ka_labels": torch.stack(padded_ka_labels),
        "drumroll_labels": torch.stack(padded_drumroll_labels),
        "lengths": conformer_input_lengths,  # consumed by the Conformer
        "nps": torch.stack([b["nps"] for b in batch]),
        "durations": torch.stack([b["duration_seconds"] for b in batch]),
    }