import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.profiler
import torchaudio

from .config import HOP_LENGTH, N_MELS, SAMPLE_RATE


# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path, nps=5.0):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # downmix to mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # peak-normalize audio
    nps_tensor = torch.tensor(nps, dtype=torch.float32)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    # mel has shape (N_MELS, T_mel); the batch dimension is added in run_inference.
    return mel, nps_tensor


# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)
        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)
        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )
        energies = out_dict["presence"].squeeze(0).cpu().numpy()
        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]
        return don_energy, ka_energy, drumroll_energy


# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames
    for i in range(1, T_out - 1):  # interior frames only, so both neighbors exist
        if i < last_onset_frame + min_distance_frames:
            continue
        e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
        # Type mapping: 1 = Don, 2 = Ka, 5 = Drumroll.
        energies_at_i = {1: e_don, 2: e_ka, 5: e_drum}
        # Visit types in descending energy order so the strongest channel is
        # tested first and wins ties for the peak condition.
        sorted_types_by_energy = sorted(
            energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
        )
        for onset_type in sorted_types_by_energy:
            if onset_type == 1:
                current_energy_series = don_energy
            elif onset_type == 2:
                current_energy_series = ka_energy
            else:  # onset_type == 5
                current_energy_series = drumroll_energy
            energy_val = current_energy_series[i]
            # Accept a local maximum above threshold. The "highest energy at
            # this frame wins" rule is enforced implicitly by the descending
            # iteration order and the break after the first detection.
            if (
                energy_val > threshold
                and energy_val > current_energy_series[i - 1]
                and energy_val > current_energy_series[i + 1]
            ):
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                break  # at most one onset type per frame
    return results


# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # model output time dimension

    # Time axes.
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    # Frame-rate note: the CNN in model.py does not stride in time, so the
    # model output length T_out equals T_mel and each output frame spans
    # HOP_LENGTH / SAMPLE_RATE seconds. TIME_SUB only subsamples the labels in
    # preprocess.py; it does not change the feature rate fed to the conformer
    # (whose `lengths` is T_mel). The hop_sec passed in must therefore match
    # the model's output frame rate; we assume the caller supplies the same
    # value used for decode_onsets.
    time_axis_energies = np.arange(T_out) * hop_sec

    fig, ax1 = plt.subplots(figsize=(100, 10))

    # Plot the mel spectrogram on ax1.
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")

    # Second y-axis for the energy curves.
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # energies are assumed roughly bounded in [0, 1]

    # Overlay onsets from decode_onsets (t is already in seconds).
    labeled_types = set()

    # Group drumrolls into segments (same grouping logic as write_tja).
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # up to a 5-frame gap
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))

    # Plot Don/Ka onsets as vertical lines.
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # drumrolls are drawn as shaded segments below
        color_map = {1: "darkred", 2: "darkblue"}
        label_map = {1: "Don Onset", 2: "Ka Onset"}
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)

    # Plot drumroll segments as shaded regions.
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")

    # Combine legends from both axes.
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")
    fig.tight_layout()

    # Save to file if a path is provided; otherwise return the figure object.
    if out_path:
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        return fig


def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav", offset=0):
    # TJA note types: 0 = no note, 1 = Don, 2 = Ka, 3 = Big Don, 4 = Big Ka,
    # 5 = drumroll start, 8 = drumroll end.
    # Model output types: 1 = Don, 2 = Ka, 5 = Drumroll (start/single).
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # assume 4/4 time
    sec_per_measure = sec_per_beat * beats_per_measure

    # Step 1: Map onsets to (measure_idx, slot, typ).
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))

    # Step 2: Build the measure/slot grid.
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}

    # Step 3: Place Don/Ka notes and collect drumroll slots.
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in (1, 2):
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))

    # Step 4: Merge drumroll slots into contiguous regions and mark each
    # region with 5 (start) and 8 (end).
    drumroll_list = sorted(drumroll_slots)

    # Group into regions, allowing a gap of up to 5 slots between hits.
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Slot distance, accounting for a wrap into the next measure.
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1:
                slot_dist = (quantize - 1 - last_s) + s + 1
            else:
                slot_dist = None
            # A gap of up to 5 slots (slot_dist <= 6) stays in the same region.
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)

    # Mark 5 (start) and 8 (end) for each region.
    for region in grouped:
        if len(region) == 1:
            m, s = region[0]
            measures[m][s] = 5
            # Place the 8 in the next slot, or at the start of the next measure.
            if s < quantize - 1:
                measures[m][s + 1] = 8
            elif m < max_measure_idx:
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Slots between start and end stay 0 (the grid default).

    # Step 5: Generate the TJA content.
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC5, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append(f"OFFSET:{offset}")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")
    tja_string = "\n".join(tja_content)

    # If out_path is provided, also write the chart to disk.
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")
    return tja_string
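

# A minimal end-to-end sketch of how these pieces fit together, for
# illustration only. `load_model` and the checkpoint name are hypothetical
# stand-ins for however the model is actually constructed; hop_sec assumes
# the model output frame rate derived above (one frame per HOP_LENGTH
# samples). Because this module uses relative imports, run it as a module
# (python -m <package>.<this_module>) rather than as a script.
if __name__ == "__main__":
    from .model import load_model  # hypothetical loader; adapt to the real one

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model("checkpoint.pt").to(device)  # assumed checkpoint name

    mel, nps = preprocess_audio("audio.wav", nps=5.0)
    don, ka, drumroll = run_inference(model, mel, nps, device)

    hop_sec = HOP_LENGTH / SAMPLE_RATE  # output frame rate: no time striding
    onsets = decode_onsets(don, ka, drumroll, hop_sec, threshold=0.5)

    plot_results(mel, don, ka, drumroll, onsets, hop_sec, out_path="chart.png")
    write_tja(onsets, out_path="chart.tja", bpm=160, audio="audio.wav")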