import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer5


mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)
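
# One mel frame covers HOP_LENGTH samples, so the mel frame rate is
# SAMPLE_RATE / HOP_LENGTH frames per second; preprocess() builds its label
# grids at this rate, further subsampled by TIME_SUB.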


freq_mask = FrequencyMasking(freq_mask_param=15)
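# FrequencyMasking zeroes a randomly chosen band of up to freq_mask_param (15)
# mel bins each time it is applied.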


def preprocess(example, difficulty="oni"):
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    # 1) resample to SAMPLE_RATE if needed
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    # peak-normalize audio to [-1, 1]
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    # augmentation: add low-level Gaussian noise with probability 0.5
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
    # 2) mel: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)
    # Apply SpecAugment (frequency masking only).
    # Time masking is skipped so the model is never asked to predict notes
    # whose audio evidence has been masked out.
    mel = freq_mask(mel)
    _, _, T = mel.shape
    # 3) build label sequence of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)

    # Initialize energy-based labels for Don, Ka, Drumroll
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    # Define exponential decay tail parameters
    tail_length = 40  # number of frames for decay tail
    decay_rate = 8.0  # decay time constant in label frames (larger = slower decay)
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )
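    # tail_kernel[i] = exp(-i / 8): 1.0 at the onset frame, ~0.37 after 8
    # frames, ~0.08 after 20 frames, ~0.008 at the end of the 40-frame tail.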

    fps = SAMPLE_RATE / HOP_LENGTH
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # N_TYPES (from config) must cover every note type handled below (e.g. 7)
        if typ < 1 or typ > N_TYPES:  # Filter out invalid types
            continue

        num_valid_notes += 1

        f = int(round(t_start.item() * fps))
        idx = f // TIME_SUB
        if 0 <= idx < T_sub:
            # Apply exponential decay kernel to the corresponding energy channel
            # Types 1 and 3 are Don
            if typ == 1 or typ == 3:
                for i, val in enumerate(tail_kernel):
                    target_idx = idx + i
                    if 0 <= target_idx < T_sub:
                        don_labels[target_idx] = max(
                            don_labels[target_idx].item(), val.item()
                        )
            # Types 2 and 4 are Ka
            elif typ == 2 or typ == 4:
                for i, val in enumerate(tail_kernel):
                    target_idx = idx + i
                    if 0 <= target_idx < T_sub:
                        ka_labels[target_idx] = max(
                            ka_labels[target_idx].item(), val.item()
                        )
            # Types 5, 6, and 7 are Drumroll
            elif typ >= 5 and typ <= 7:
                f_end = int(round(t_end.item() * fps))
                idx_end = f_end // TIME_SUB
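                # The whole drumroll span [idx, idx_end) is labeled 1.0, and the
                # usual decay tail is appended after the drumroll ends.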

                for dr in range(idx, idx_end):
                    if 0 <= dr < T_sub:
                        drumroll_labels[dr] = 1.0

                for i, val in enumerate(tail_kernel):
                    target_idx = idx_end + i
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(), val.item()
                        )

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0
    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}"
    )

    return {
        "mel": mel,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "nps": torch.tensor(nps, dtype=torch.float32),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
    }


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)
    # Extract new energy-based labels
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]

    nps_list = [b["nps"] for b in batch]  # Extract NPS
    durations_list = [b["duration_seconds"] for b in batch]  # Extract durations

    # Pad mels
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max, N_MELS)
    # Reshape for CNN: (B, 1, N_MELS, T_max)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)

    # Infer the CNN front-end's output time length by running an untrained
    # TaikoConformer5 instance; only its .cnn submodule is used, and only for shapes.
    dummy_model_for_shape_inference = TaikoConformer5()
    dummy_cnn = dummy_model_for_shape_inference.cnn
    with torch.no_grad():
        cnn_out = dummy_cnn(reshaped_mels)  # reshaped_mels already has a batch dim
        _, _, _, T_cnn = cnn_out.shape
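    # T_cnn: padded time length after the CNN front-end; the label channels
    # below are aligned to it.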

    def _pad_or_truncate(labels, out_len):
        # Right-pad with zeros (no energy) or truncate to out_len.
        if labels.shape[0] < out_len:
            pad = torch.zeros(
                out_len - labels.shape[0], dtype=labels.dtype, device=labels.device
            )
            return torch.cat([labels, pad], dim=0)
        return labels[:out_len]

    # All three label channels of an item share the same original length (T_sub).
    padded_don_labels = [_pad_or_truncate(d, T_cnn) for d in don_labels_list]
    padded_ka_labels = [_pad_or_truncate(k, T_cnn) for k in ka_labels_list]
    padded_drumroll_labels = [
        _pad_or_truncate(dr, T_cnn) for dr in drumroll_labels_list
    ]

    # Conformer input lengths: the number of valid (unpadded) frames per item
    # on the TIME_SUB-subsampled grid, clamped to the CNN output length T_cnn.
    # If the CNN's time subsampling ever diverges from TIME_SUB, this
    # calculation will need to be revisited.
    conformer_input_lengths = [
        math.ceil(mels_list[i].shape[0] / TIME_SUB) for i in range(len(batch))
    ]
    conformer_input_lengths = torch.tensor(
        [min(l, T_cnn) for l in conformer_input_lengths], dtype=torch.long
    )

    return {
        "mel": reshaped_mels,
        "don_labels": torch.stack(padded_don_labels),
        "ka_labels": torch.stack(padded_ka_labels),
        "drumroll_labels": torch.stack(padded_drumroll_labels),
        "lengths": conformer_input_lengths,  # These are for the Conformer model
        "nps": torch.stack(nps_list),
        "durations": torch.stack(durations_list),
    }
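

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the training pipeline. It
    # assumes dataset rows look like the fields accessed above: an "audio"
    # dict plus a per-difficulty list of (type, start_sec, end_sec) onsets
    # stored as tensors. Run as a module so the relative imports resolve.
    dummy_example = {
        "audio": {
            "array": torch.randn(int(SAMPLE_RATE * 2)),  # two seconds of noise
            "sampling_rate": SAMPLE_RATE,
        },
        "oni": [
            (1, torch.tensor(0.50), torch.tensor(0.50)),  # Don
            (2, torch.tensor(1.00), torch.tensor(1.00)),  # Ka
            (5, torch.tensor(1.20), torch.tensor(1.80)),  # Drumroll
        ],
    }
    item = preprocess(dummy_example)
    batch = collate_fn([item, item])
    for key, value in batch.items():
        print(key, tuple(value.shape))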