import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.profiler
import torchaudio

from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH


# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # Normalize audio

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    return mel  # mel is (N_MELS, T_mel)
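
# Usage sketch for preprocess_audio (the path "song.ogg" is illustrative);
# the frame-to-time conversion follows from the MelSpectrogram hop length:
#
#     mel = preprocess_audio("song.ogg")      # (N_MELS, T_mel)
#     frame_sec = HOP_LENGTH / SAMPLE_RATE    # seconds per mel frame
#     total_sec = mel.shape[1] * frame_sec    # approximate clip duration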


# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, difficulty_input, level_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)
        difficulty = difficulty_input.to(device).unsqueeze(0)  # (1,)
        level = level_input.to(device).unsqueeze(0)  # (1,)

        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)

        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps, difficulty, level)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )

        energies = out_dict["presence"].squeeze(0).cpu().numpy()

        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]

    return don_energy, ka_energy, drumroll_energy
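
# Usage sketch for run_inference: `model` is the trained network, whose
# forward returns a dict containing "presence" of shape (B, T_out, 3).
# The conditioning values below are illustrative only; their exact dtypes
# depend on how the model embeds them.
#
#     don, ka, roll = run_inference(
#         model,
#         preprocess_audio("song.ogg"),
#         torch.tensor(8.0),   # notes-per-second target (example)
#         torch.tensor(3),     # difficulty index (example)
#         torch.tensor(9),     # level (example)
#         torch.device("cpu"),
#     )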


# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames

    for i in range(1, T_out - 1):  # Iterate considering neighbors for peak detection
        if i < last_onset_frame + min_distance_frames:
            continue

        e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
        energies_at_i = {
            1: e_don,
            2: e_ka,
            5: e_drum,
        }  # Type mapping: 1:Don, 2:Ka, 5:Drumroll

        # Check types in descending order of energy so that, when more than
        # one series peaks above the threshold at this frame, the strongest
        # one wins.
        sorted_types_by_energy = sorted(
            energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
        )
        series_by_type = {1: don_energy, 2: ka_energy, 5: drumroll_energy}

        for onset_type in sorted_types_by_energy:
            current_energy_series = series_by_type[onset_type]
            energy_val = current_energy_series[i]

            # Accept an onset only if it clears the threshold and is a local
            # peak relative to its immediate neighbors.
            if (
                energy_val > threshold
                and energy_val > current_energy_series[i - 1]
                and energy_val > current_energy_series[i + 1]
            ):
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                break  # At most one onset type per frame

    return results
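
# Usage sketch for decode_onsets: hop_sec must equal the duration of one
# model-output frame (HOP_LENGTH / SAMPLE_RATE here, since the model does
# not subsample time; see the note in plot_results below). The energy
# arrays come from run_inference:
#
#     hop_sec = HOP_LENGTH / SAMPLE_RATE
#     onsets = decode_onsets(don, ka, roll, hop_sec, threshold=0.5)
#     # -> [(time_in_seconds, type), ...], types 1:Don, 2:Ka, 5:Drumroll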


# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # Length of energy arrays (model output time dimension)

    # Time axes
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    # hop_sec is the duration of one model-output frame. The CNN in model.py
    # does not stride in time, so the conformer (and regressor) output length
    # T_out equals T_mel: one output frame per mel frame of
    # HOP_LENGTH / SAMPLE_RATE seconds. TIME_SUB only affects how labels are
    # subsampled in preprocess.py, not the feature path here, so the caller's
    # hop_sec is taken as authoritative when converting frames to seconds.
    time_axis_energies = np.arange(T_out) * hop_sec

    fig, ax1 = plt.subplots(figsize=(100, 10))

    # Plot Mel Spectrogram on ax1
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")

    # Create a second y-axis for energies
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # Assuming energies are somewhat normalized or bounded

    # Overlay onsets from decode_onsets (t is already in seconds)
    labeled_types = set()
    # Group drumrolls into segments (reuse logic from write_tja)
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # up to 5-frame gap
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))
    # Plot Don/Ka onsets as vertical lines
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # skip drumroll onsets
        color_map = {1: "darkred", 2: "darkblue"}
        label_map = {1: "Don Onset", 2: "Ka Onset"}
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
    # Plot drumroll segments as shaded regions
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")

    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    fig.tight_layout()

    # Return plot as image buffer or save to file if path provided
    if out_path:
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        # Return plot as in-memory buffer
        return fig


def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav", offset=0):
    # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
    # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # Assuming 4/4 time signature
    sec_per_measure = sec_per_beat * beats_per_measure
    # Step 1: Map onsets to (measure_idx, slot, typ)
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))
    # Step 2: Build measure/slot grid
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
    # Step 3: Place Don/Ka, collect drumrolls
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in [1, 2]:
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))
    # Step 4: Process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
    # Flatten all slots to a list of (measure, slot) sorted
    drumroll_list = sorted(list(drumroll_slots))
    # Group into contiguous regions (allowing a gap of 5 slots)
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Slot distance between consecutive drumroll hits, allowing a
            # wrap across one measure boundary. (Slots are clamped to
            # quantize - 1, so only the adjacent-measure case needs handling.)
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1:
                slot_dist = (quantize - last_s) + s
            else:
                slot_dist = None
            # Allow gap of up to 5 slots (slot_dist <= 6)
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)
    # Mark 5 (start) and 8 (end) for each group
    for region in grouped:
        if len(region) == 1:
            m, s = region[0]
            measures[m][s] = 5
            # Place 8 in the next slot, or at the start of the following
            # measure, creating a trailing measure if needed so the roll
            # is always terminated.
            if s < quantize - 1:
                measures[m][s + 1] = 8
            else:
                if m + 1 not in measures:
                    measures[m + 1] = [0] * quantize
                    max_measure_idx = m + 1
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Fill 0 for middle slots (already 0 by default)
    # Step 5: Generate TJA content
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC6, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append(f"OFFSET:{offset}")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")

    tja_string = "\n".join(tja_content)

    # If out_path is provided, also write to file
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")

    return tja_string
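

if __name__ == "__main__":
    # Minimal end-to-end sketch, guarded so it does not run on import.
    # Assumptions: the checkpoint was saved as a full pickled model via
    # torch.save(model, ...); file names and conditioning values are
    # illustrative only. Adapt the loading step if only a state_dict exists.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load("checkpoint.pt", map_location=device, weights_only=False)
    mel = preprocess_audio("song.ogg")
    don, ka, roll = run_inference(
        model, mel, torch.tensor(8.0), torch.tensor(3), torch.tensor(9), device
    )
    hop_sec = HOP_LENGTH / SAMPLE_RATE  # one model-output frame in seconds
    onsets = decode_onsets(don, ka, roll, hop_sec)
    plot_results(mel, don, ka, roll, onsets, hop_sec, out_path="chart.png")
    write_tja(onsets, out_path="chart.tja", audio="song.ogg")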