import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from tja import parse_tja, PyParsingMode
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer7


mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)


freq_mask = FrequencyMasking(freq_mask_param=15)
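# FrequencyMasking zeroes a random band of up to `freq_mask_param` mel bins per
# call (light SpecAugment); it is applied once per example inside `preprocess`.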


def preprocess(example, difficulty="oni"):
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    # 1) load & resample
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    # normalize audio
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    # add random Gaussian noise
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
    # 2) mel: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)
    # apply SpecAugment
    mel = freq_mask(mel)
    _, _, T = mel.shape
    # 3) build label sequence of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)

    # Initialize energy-based labels for Don, Ka, Drumroll
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)
    sliding_nps_labels = torch.zeros(
        T_sub, dtype=torch.float32
    )  # New label for sliding NPS

    # Define exponential decay tail parameters
    tail_length = 40  # number of frames for decay tail
    decay_rate = 8.0  # decay rate parameter, adjust as needed
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )
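    # With tail_length=40 and decay_rate=8.0 the kernel decays from exp(0) = 1.0
    # down to exp(-39/8) ≈ 0.008, so each onset leaves a soft, exponentially
    # fading target rather than a single one-hot frame.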

    fps = SAMPLE_RATE / HOP_LENGTH
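    # Time base: `fps` is mel frames per second, so one label step covers
    # TIME_SUB frames, i.e. TIME_SUB / fps seconds.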
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
        if typ < 1 or typ > N_TYPES:  # Filter out invalid types
            continue

        num_valid_notes += 1

        exact_frame_start = t_start.item() * fps

        # Type 1 and 3 are Don, Type 2 and 4 are Ka
        if typ in (1, 2, 3, 4):
            exact_hit_time_sub = exact_frame_start / TIME_SUB

            current_labels = don_labels if (typ == 1 or typ == 3) else ka_labels

            start_points_info = []
            rounded_hit_time_sub = round(exact_hit_time_sub)

            if (
                abs(exact_hit_time_sub - rounded_hit_time_sub) < 1e-6
            ):  # Tolerance for float precision
                idx_single = int(rounded_hit_time_sub)
                if 0 <= idx_single < T_sub:
                    start_points_info.append({"idx": idx_single, "weight": 1.0})
            else:
                idx_floor = math.floor(exact_hit_time_sub)
                idx_ceil = idx_floor + 1

                frac = exact_hit_time_sub - idx_floor
                weight_ceil = frac
                weight_floor = 1.0 - frac
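                # e.g. exact_hit_time_sub = 10.25 gives frac = 0.25: weight 0.75
                # lands on frame 10 and weight 0.25 on frame 11.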

                if weight_floor > 1e-6 and 0 <= idx_floor < T_sub:
                    start_points_info.append({"idx": idx_floor, "weight": weight_floor})
                if weight_ceil > 1e-6 and 0 <= idx_ceil < T_sub:
                    start_points_info.append({"idx": idx_ceil, "weight": weight_ceil})

            for point_info in start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        current_labels[target_idx] = max(
                            current_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

        # Type 5, 6, 7 are Drumroll
        elif 5 <= typ <= 7:
            exact_frame_end = t_end.item() * fps
            exact_start_time_sub = exact_frame_start / TIME_SUB
            exact_end_time_sub = exact_frame_end / TIME_SUB

            # Drumroll body: mark every label step the roll overlaps
            body_loop_start_idx = math.floor(exact_start_time_sub)
            body_loop_end_idx = math.ceil(exact_end_time_sub)

            for dr_idx in range(body_loop_start_idx, body_loop_end_idx):
                if 0 <= dr_idx < T_sub:
                    drumroll_labels[dr_idx] = 1.0

            # Drumroll tail: decay kernel anchored at exact_end_time_sub
            tail_start_points_info = []
            rounded_end_time_sub = round(exact_end_time_sub)
            if abs(exact_end_time_sub - rounded_end_time_sub) < 1e-6:
                idx_single_tail = int(rounded_end_time_sub)
                if 0 <= idx_single_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_single_tail, "weight": 1.0}
                    )
            else:
                idx_floor_tail = math.floor(exact_end_time_sub)
                idx_ceil_tail = idx_floor_tail + 1

                frac_tail = exact_end_time_sub - idx_floor_tail
                weight_ceil_tail = frac_tail
                weight_floor_tail = 1.0 - frac_tail

                if weight_floor_tail > 1e-6 and 0 <= idx_floor_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_floor_tail, "weight": weight_floor_tail}
                    )
                if weight_ceil_tail > 1e-6 and 0 <= idx_ceil_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_ceil_tail, "weight": weight_ceil_tail}
                    )

            for point_info in tail_start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

    # Calculate sliding window NPS
    # Store tuples of (time_sec, is_drumroll, duration_if_drumroll)
    note_events = []
    for onset in example[difficulty]:
        typ, t_start_tensor, t_end_tensor, *_ = onset
        t_start = t_start_tensor.item()
        t_end = t_end_tensor.item()

        if typ in [1, 2, 3, 4]:  # Don or Ka
            note_events.append(
                (t_start, False, 0)
            )  # False indicates not a drumroll event, duration 0
        elif 5 <= typ <= 7:  # Drumroll
            note_events.append(
                (t_start, True, t_end - t_start)
            )  # True indicates drumroll start, store duration
            # We don't explicitly need a drumroll end event for this calculation method

    note_events.sort(key=lambda x: x[0])  # Sort by time

    window_duration_seconds = 0.5

    # Step 1: base Don/Ka-only sliding NPS (used later to set adaptive drumroll rates)
    base_don_ka_sliding_nps = torch.zeros(T_sub, dtype=torch.float32)
    time_step_duration_sec = TIME_SUB / fps  # Duration of one T_sub segment

    for k_idx in range(T_sub):
        k_window_end_sec = ((k_idx + 1) * TIME_SUB) / fps
        k_window_start_sec = k_window_end_sec - window_duration_seconds

        current_don_ka_count = 0.0
        for event_t, is_drumroll, _ in note_events:
            if not is_drumroll:  # Don or Ka hit
                if k_window_start_sec <= event_t < k_window_end_sec:
                    current_don_ka_count += 1
        base_don_ka_sliding_nps[k_idx] = current_don_ka_count / window_duration_seconds

    # Step 2: Calculate adaptive_drumroll_rates_for_all_events
    adaptive_drumroll_rates_for_all_events = []
    for event_t, is_drumroll, drumroll_dur in note_events:
        if is_drumroll:
            drumroll_start_sec = event_t
            drumroll_end_sec = event_t + drumroll_dur

            slice_start_idx = math.floor(drumroll_start_sec / time_step_duration_sec)
            slice_end_idx = math.ceil(drumroll_end_sec / time_step_duration_sec)

            slice_start_idx = max(0, slice_start_idx)
            slice_end_idx = min(T_sub, slice_end_idx)

            max_nps_in_drumroll_period = 0.0
            if slice_start_idx < slice_end_idx:
                relevant_base_nps_values = base_don_ka_sliding_nps[
                    slice_start_idx:slice_end_idx
                ]
                if relevant_base_nps_values.numel() > 0:
                    max_nps_in_drumroll_period = torch.max(
                        relevant_base_nps_values
                    ).item()

            rate = max(5.0, max_nps_in_drumroll_period)
            adaptive_drumroll_rates_for_all_events.append(rate)
        else:
            adaptive_drumroll_rates_for_all_events.append(0.0)  # Placeholder

    # Step 3: Calculate final sliding_nps_labels using adaptive rates
    # sliding_nps_labels is already initialized with zeros earlier in the function.
    for k_idx in range(T_sub):
        k_window_end_sec = ((k_idx + 1) * TIME_SUB) / fps
        k_window_start_sec = k_window_end_sec - window_duration_seconds

        current_window_total_nps_contribution = 0.0
        for event_idx, (event_t, is_drumroll, drumroll_dur) in enumerate(note_events):
            if is_drumroll:
                drumroll_start_sec = event_t
                drumroll_end_sec = event_t + drumroll_dur

                overlap_start = max(k_window_start_sec, drumroll_start_sec)
                overlap_end = min(k_window_end_sec, drumroll_end_sec)

                if overlap_end > overlap_start:
                    overlap_duration = overlap_end - overlap_start
                    current_adaptive_rate = adaptive_drumroll_rates_for_all_events[
                        event_idx
                    ]
                    current_window_total_nps_contribution += (
                        overlap_duration * current_adaptive_rate
                    )
            else:  # Don or Ka hit
                if k_window_start_sec <= event_t < k_window_end_sec:
                    current_window_total_nps_contribution += (
                        1  # Each hit contributes 1 to the count
                    )

        sliding_nps_labels[k_idx] = (
            current_window_total_nps_contribution / window_duration_seconds
        )

    # Normalize sliding_nps_labels to 0-1 range
    if T_sub > 0:  # Ensure there are elements to normalize
        min_nps_val = torch.min(sliding_nps_labels)
        max_nps_val = torch.max(sliding_nps_labels)
        denominator = max_nps_val - min_nps_val
        if denominator > 1e-6:  # Use a small epsilon for float comparison
            sliding_nps_labels = (sliding_nps_labels - min_nps_val) / denominator
        else:
            # If all values are (nearly) the same
            if max_nps_val > 1e-6:  # If the constant value is positive
                sliding_nps_labels = torch.ones_like(sliding_nps_labels)
            else:  # If the constant value is zero (or very close to it)
                sliding_nps_labels = torch.zeros_like(sliding_nps_labels)
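    # NOTE: min-max scaling makes this label relative to each chart's own
    # density range: 0 marks the sparsest 0.5 s window of the song, 1 the densest.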

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0

    parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
    chart = next(
        (chart for chart in parsed.charts if chart.course.lower() == difficulty), None
    )
    difficulty_id = {"easy": 0, "normal": 1, "hard": 2, "oni": 3}.get(
        difficulty, 4  # Assuming 4 for edit/ura
    )
    level = chart.level if chart else 0

    # --- CNN shape inference and label padding/truncation ---
    # Simulate CNN to get output time length (T_cnn)
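    # Only the CNN front-end's temporal downsampling matters here; the randomly
    # initialized weights are irrelevant because we only read the output shape.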
    dummy_model = TaikoConformer7()
    with torch.no_grad():
        cnn_out = dummy_model.cnn(mel.unsqueeze(0))  # (1, C, F, T_cnn)
        _, _, _, T_cnn = cnn_out.shape

    # Pad or truncate labels to T_cnn
    def pad_or_truncate(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        else:
            return label[:out_len]

    don_labels = pad_or_truncate(don_labels, T_cnn)
    ka_labels = pad_or_truncate(ka_labels, T_cnn)
    drumroll_labels = pad_or_truncate(drumroll_labels, T_cnn)
    sliding_nps_labels = pad_or_truncate(sliding_nps_labels, T_cnn)  # Pad new label

    # For conformer input lengths: this should be T_cnn
    conformer_sequence_length = T_cnn  # This is the actual sequence length after CNN

    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
    )

    return {
        "mel": mel,  # (1, N_MELS, T)
        "don_labels": don_labels,  # (T_cnn,)
        "ka_labels": ka_labels,  # (T_cnn,)
        "drumroll_labels": drumroll_labels,  # (T_cnn,)
        "sliding_nps_labels": sliding_nps_labels,  # Add new label (T_cnn,)
        "nps": torch.tensor(nps, dtype=torch.float32),
        "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
        "level": torch.tensor(level, dtype=torch.long),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
        "length": torch.tensor(
            conformer_sequence_length, dtype=torch.long
        ),  # Use T_cnn for conformer and loss masking
    }


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]
    sliding_nps_labels_list = [b["sliding_nps_labels"] for b in batch]  # New label list
    nps_list = [b["nps"] for b in batch]
    difficulty_list = [b["difficulty"] for b in batch]
    level_list = [b["level"] for b in batch]
    durations_list = [b["duration_seconds"] for b in batch]
    lengths_list = [b["length"] for b in batch]  # These are T_cnn_i for each example

    # Pad mels
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max_mel, N_MELS)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)

    # Determine max sequence length for labels (max T_cnn in batch)
    max_label_len = max((l.item() for l in lengths_list), default=0)  # 0 for an empty batch

    # Pad labels to max_label_len (max_t_cnn_in_batch)
    def pad_label_to_max_len(label_tensor, target_len):
        current_len = label_tensor.shape[0]
        if current_len < target_len:
            padding_size = target_len - current_len
            # Ensure padding is created on the same device as the label_tensor
            padding = torch.zeros(
                padding_size, dtype=label_tensor.dtype, device=label_tensor.device
            )
            return torch.cat((label_tensor, padding), dim=0)
        elif (
            current_len > target_len
        ):  # Should ideally not happen if lengths_list is correct
            return label_tensor[:target_len]
        return label_tensor

    don_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in don_labels_list]
    )
    ka_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in ka_labels_list]
    )
    drumroll_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in drumroll_labels_list]
    )
    sliding_nps_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in sliding_nps_labels_list]
    )  # Pad new labels

    actual_lengths = torch.tensor([l.item() for l in lengths_list], dtype=torch.long)

    return {
        "mel": reshaped_mels,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "sliding_nps_labels": sliding_nps_labels,  # Add new batched labels
        "lengths": actual_lengths,  # for conformer and loss masking (T_cnn_i for each item)
        "nps": torch.stack(nps_list),
        "difficulty": torch.stack(difficulty_list),
        "level": torch.stack(level_list),
        "durations": torch.stack(durations_list),
    }
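

# Minimal smoke-test sketch (not part of the original pipeline). It fabricates
# one synthetic example matching the schema `preprocess` reads above: an
# "audio" dict, a per-difficulty list of (type, t_start, t_end) tensors, and a
# "tja" field. Assumption: `parse_tja` accepts the raw TJA text stored in
# example["tja"]; adapt the fixture if your dataset stores a path instead.
if __name__ == "__main__":
    dummy_example = {
        "audio": {
            "array": torch.randn(SAMPLE_RATE * 4),  # 4 s of noise stands in for audio
            "sampling_rate": SAMPLE_RATE,
        },
        "oni": [
            (1, torch.tensor(0.50), torch.tensor(0.50)),  # Don
            (2, torch.tensor(1.00), torch.tensor(1.00)),  # Ka
            (5, torch.tensor(2.00), torch.tensor(3.00)),  # 1 s drumroll
        ],
        "tja": "TITLE:demo\nCOURSE:Oni\nLEVEL:8\n\n#START\n#END\n",
    }
    item = preprocess(dummy_example, difficulty="oni")
    batch = collate_fn([item])
    print(batch["mel"].shape)         # (1, 1, N_MELS, T)
    print(batch["don_labels"].shape)  # (1, T_cnn)
    print(batch["lengths"])           # tensor([T_cnn])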