# tc5-exp/tc7/preprocess.py
# Committed by JacobLinCool
# Commit message: Implement TaikoConformer7 model, loss function, preprocessing, and training pipeline
# Commit: 812b01c
import bisect
import math

import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from tja import parse_tja, PyParsingMode

from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer7
# Module-level transforms used by ``preprocess``.
# Mel spectrogram front-end: FFT size 2048, one frame every HOP_LENGTH
# samples, N_MELS mel bins, for audio at SAMPLE_RATE.
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=2048,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
)
# SpecAugment-style frequency masking; the mask width is bounded by
# freq_mask_param (15 mel bins).
freq_mask = FrequencyMasking(freq_mask_param=15)
# Difficulty name -> id; any other course name (presumably edit/ura) maps to 4.
_DIFFICULTY_IDS = {"easy": 0, "normal": 1, "hard": 2, "oni": 3}

# Exponential decay tail stamped after every hit onset and drumroll end.
_TAIL_LENGTH = 40  # number of label frames in the tail
_TAIL_DECAY_RATE = 8.0  # larger -> slower decay

# Sliding notes-per-second target parameters.
_NPS_WINDOW_SECONDS = 0.5
_MIN_DRUMROLL_RATE = 5.0  # floor for the adaptive per-drumroll hit rate

# Model instance used only to measure the CNN output length; built lazily,
# once per process, instead of once per example.
_SHAPE_PROBE_MODEL = None


def _spread_decay_tail(labels, position, kernel):
    """Stamp ``kernel`` into ``labels`` starting at fractional index ``position``.

    If ``position`` is within 1e-6 of an integer, the kernel is stamped there
    with weight 1.0. Otherwise it is split between the floor and ceil indices
    with linear-interpolation weights (a weight <= 1e-6 is dropped). Each
    stamp takes the element-wise maximum with the existing labels instead of
    summing. A stamp whose start index lies outside ``[0, len(labels))`` is
    skipped, and tails are clipped at the end of ``labels``.

    Modifies ``labels`` in place.
    """
    n = labels.shape[0]
    nearest = round(position)
    if abs(position - nearest) < 1e-6:  # tolerance for float precision
        points = [(int(nearest), 1.0)]
    else:
        lo = math.floor(position)
        frac = position - lo
        points = [(lo, 1.0 - frac), (lo + 1, frac)]

    for start, weight in points:
        if weight <= 1e-6 or not 0 <= start < n:
            continue
        end = min(start + kernel.shape[0], n)
        labels[start:end] = torch.maximum(
            labels[start:end], weight * kernel[: end - start]
        )


def _sliding_nps(
    hit_times,
    drumrolls,
    t_sub,
    fps,
    time_sub,
    window_seconds=_NPS_WINDOW_SECONDS,
    min_drumroll_rate=_MIN_DRUMROLL_RATE,
):
    """Return raw (un-normalised) notes-per-second for each label frame.

    Frame ``k`` looks at the window ``[end_k - window_seconds, end_k)``, where
    ``end_k = (k + 1) * time_sub / fps``. Each Don/Ka hit inside the window
    counts once. Each drumroll contributes ``overlap_seconds * rate``, where
    ``rate`` is the peak Don/Ka-only NPS over the frames the drumroll spans,
    but never less than ``min_drumroll_rate``.

    Args:
        hit_times: Don/Ka onset times in seconds, in any order.
        drumrolls: ``(start_sec, end_sec)`` pairs.
        t_sub: Number of label frames.
        fps: Spectrogram frames per second.
        time_sub: Spectrogram frames per label frame.
        window_seconds: Length of the trailing window.
        min_drumroll_rate: Lower bound on a drumroll's hit rate.

    Returns:
        float32 tensor of shape ``(t_sub,)``.
    """
    hits = sorted(hit_times)

    windows = []
    for k in range(t_sub):
        window_end = ((k + 1) * time_sub) / fps
        windows.append((window_end - window_seconds, window_end))

    def count_hits(start, end):
        # Number of hits t with start <= t < end; ``hits`` is sorted.
        return bisect.bisect_left(hits, end) - bisect.bisect_left(hits, start)

    hit_counts = [count_hits(start, end) for start, end in windows]
    base_nps = torch.tensor(
        [count / window_seconds for count in hit_counts], dtype=torch.float32
    )

    # Give every drumroll a rate tied to the surrounding Don/Ka density.
    frame_seconds = time_sub / fps
    rates = []
    for dr_start, dr_end in drumrolls:
        lo = max(0, math.floor(dr_start / frame_seconds))
        hi = min(t_sub, math.ceil(dr_end / frame_seconds))
        peak = base_nps[lo:hi].max().item() if lo < hi else 0.0
        rates.append(max(min_drumroll_rate, peak))

    nps = torch.zeros(t_sub, dtype=torch.float32)
    for k, (start, end) in enumerate(windows):
        total = float(hit_counts[k])
        for (dr_start, dr_end), rate in zip(drumrolls, rates):
            overlap = min(end, dr_end) - max(start, dr_start)
            if overlap > 0:
                total += overlap * rate
        nps[k] = total / window_seconds
    return nps


def _minmax_normalize(values):
    """Scale ``values`` into [0, 1].

    A (near-)constant input becomes all ones if its value is positive, and
    all zeros otherwise. An empty input is returned unchanged.
    """
    if values.numel() == 0:
        return values
    lo, hi = values.min(), values.max()
    span = hi - lo
    if span > 1e-6:
        return (values - lo) / span
    return torch.ones_like(values) if hi > 1e-6 else torch.zeros_like(values)


def _cnn_output_length(mel):
    """Return the time length the model's CNN front-end produces for ``mel``.

    Only the output shape is used, so the probe model's (untrained) weights
    do not matter. It is created on first use, reused afterwards, and put in
    eval mode because it is only ever run for inference.
    """
    global _SHAPE_PROBE_MODEL
    if _SHAPE_PROBE_MODEL is None:
        _SHAPE_PROBE_MODEL = TaikoConformer7().eval()
    with torch.no_grad():
        cnn_out = _SHAPE_PROBE_MODEL.cnn(mel.unsqueeze(0))  # (1, C, F, T_cnn)
    _, _, _, t_cnn = cnn_out.shape
    return t_cnn


def _pad_or_truncate(label, out_len):
    """Right-pad ``label`` with zeros, or truncate it, to exactly ``out_len`` elements."""
    if label.shape[0] < out_len:
        pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
        return torch.cat([label, pad], dim=0)
    return label[:out_len]


def preprocess(example, difficulty="oni", *, augment=True):
    """Turn one dataset example into model inputs and frame-level targets.

    Args:
        example: Mapping with ``"audio"`` (``"array"`` waveform and
            ``"sampling_rate"``), ``"tja"`` (chart text for ``parse_tja``), and,
            under ``difficulty``, a sequence of note tuples
            ``(type, t_start, t_end, ...)`` with times in seconds. Assumes the
            waveform is a 1-D torch tensor — TODO confirm against the dataset.
        difficulty: Course to extract. ``"easy"``, ``"normal"``, ``"hard"`` and
            ``"oni"`` map to difficulty ids 0-3; any other name maps to 4.
        augment: When True (the default, matching the original behaviour),
            add Gaussian noise with probability 0.5 and apply frequency
            masking. Pass False for deterministic preprocessing, e.g. for
            validation data.

    Returns:
        A dict with:
            ``"mel"``: shape ``(1, N_MELS, T)``.
            ``"don_labels"``, ``"ka_labels"``, ``"drumroll_labels"``,
            ``"sliding_nps_labels"``: length ``T_cnn``, the CNN output length.
            ``"nps"``, ``"difficulty"``, ``"level"``, ``"duration_seconds"``,
            ``"length"`` (= ``T_cnn``): scalar tensors.
    """
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]

    # 1) Resample, peak-normalise, optionally add noise.
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    if augment and torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)

    # 2) Mel spectrogram: (1, N_MELS, T), optionally with SpecAugment.
    mel = mel_transform(wav_tensor).unsqueeze(0)
    if augment:
        mel = freq_mask(mel)
    _, _, T = mel.shape

    # 3) Frame-level targets at 1/TIME_SUB of the spectrogram frame rate.
    T_sub = math.ceil(T / TIME_SUB)
    fps = SAMPLE_RATE / HOP_LENGTH
    tail_kernel = torch.exp(
        -torch.arange(0, _TAIL_LENGTH, dtype=torch.float32) / _TAIL_DECAY_RATE
    )

    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    hit_times = []  # Don/Ka onsets (s), for the NPS target
    drumrolls = []  # (start_s, end_s), for the NPS target
    num_valid_notes = 0

    for typ, t_start, t_end, *_ in example[difficulty]:
        typ = int(typ)
        t_start = float(t_start)
        t_end = float(t_end)

        # The NPS target uses every Don/Ka/drumroll note, independent of N_TYPES.
        if typ in (1, 2, 3, 4):
            hit_times.append(t_start)
        elif 5 <= typ <= 7:
            drumrolls.append((t_start, t_end))

        if typ < 1 or typ > N_TYPES:  # filter out invalid types
            continue
        num_valid_notes += 1

        start_pos = t_start * fps / TIME_SUB  # fractional label-frame index
        if typ in (1, 3):  # Don
            _spread_decay_tail(don_labels, start_pos, tail_kernel)
        elif typ in (2, 4):  # Ka
            _spread_decay_tail(ka_labels, start_pos, tail_kernel)
        elif 5 <= typ <= 7:  # drumroll: flat body, then a decay tail after its end
            end_pos = t_end * fps / TIME_SUB
            body_lo = max(math.floor(start_pos), 0)
            body_hi = min(math.ceil(end_pos), T_sub)
            if body_lo < body_hi:
                drumroll_labels[body_lo:body_hi] = 1.0
            _spread_decay_tail(drumroll_labels, end_pos, tail_kernel)

    sliding_nps_labels = _minmax_normalize(
        _sliding_nps(hit_times, drumrolls, T_sub, fps, TIME_SUB)
    )

    # 4) Example-level metadata.
    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0

    parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
    chart = next(
        (c for c in parsed.charts if c.course.lower() == difficulty), None
    )
    difficulty_id = _DIFFICULTY_IDS.get(difficulty, 4)
    level = chart.level if chart else 0

    # 5) Align all labels with the CNN output length used by the conformer and loss.
    T_cnn = _cnn_output_length(mel)
    don_labels = _pad_or_truncate(don_labels, T_cnn)
    ka_labels = _pad_or_truncate(ka_labels, T_cnn)
    drumroll_labels = _pad_or_truncate(drumroll_labels, T_cnn)
    sliding_nps_labels = _pad_or_truncate(sliding_nps_labels, T_cnn)

    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
    )
    return {
        "mel": mel,  # (1, N_MELS, T)
        "don_labels": don_labels,  # (T_cnn,)
        "ka_labels": ka_labels,  # (T_cnn,)
        "drumroll_labels": drumroll_labels,  # (T_cnn,)
        "sliding_nps_labels": sliding_nps_labels,  # (T_cnn,)
        "nps": torch.tensor(nps, dtype=torch.float32),
        "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
        "level": torch.tensor(level, dtype=torch.long),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
        # Sequence length after the CNN; used for conformer input and loss masking.
        "length": torch.tensor(T_cnn, dtype=torch.long),
    }
def collate_fn(batch):
    """Collate ``preprocess`` outputs into one padded batch.

    Mels are zero-padded along time to the longest example and returned as
    ``(B, 1, N_MELS, T_max)``. The four per-frame label tensors are zero-padded,
    or truncated, to the largest ``"length"`` in the batch. Scalar fields are
    stacked. ``"length"`` is exposed as ``"lengths"`` (int64) and
    ``"duration_seconds"`` as ``"durations"``.
    """
    label_keys = ("don_labels", "ka_labels", "drumroll_labels", "sliding_nps_labels")

    # Make each mel time-major, (T, N_MELS), so pad_sequence pads the time axis.
    time_major_mels = [item["mel"].squeeze(0).transpose(0, 1) for item in batch]
    mel_batch = nn.utils.rnn.pad_sequence(time_major_mels, batch_first=True)  # (B, T_max, N_MELS)
    mel_batch = mel_batch.transpose(1, 2).unsqueeze(1)  # (B, 1, N_MELS, T_max)

    # Per-example CNN output lengths; labels are aligned to the longest one.
    seq_lengths = [item["length"].item() for item in batch]
    target_len = max(seq_lengths, default=0)

    def fit(label):
        # Truncate to target_len, then right-pad with zeros on the label's own device.
        clipped = label[:target_len]
        shortfall = target_len - clipped.shape[0]
        if shortfall <= 0:
            return clipped
        filler = torch.zeros(shortfall, dtype=clipped.dtype, device=clipped.device)
        return torch.cat((clipped, filler), dim=0)

    stacked_labels = {
        key: torch.stack([fit(item[key]) for item in batch]) for key in label_keys
    }

    return {
        "mel": mel_batch,
        **stacked_labels,
        "lengths": torch.tensor(seq_lengths, dtype=torch.long),
        "nps": torch.stack([item["nps"] for item in batch]),
        "difficulty": torch.stack([item["difficulty"] for item in batch]),
        "level": torch.stack([item["level"] for item in batch]),
        "durations": torch.stack([item["duration_seconds"] for item in batch]),
    }