import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.profiler
import torchaudio

from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH

# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path, nps=5.0):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # normalize audio
    nps_tensor = torch.tensor(nps, dtype=torch.float32)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    # mel shape is (N_MELS, T_mel); the batch dimension is added later in run_inference
    return mel, nps_tensor
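
# Usage sketch for preprocess_audio ("song.ogg" is a placeholder path):
#
#   mel, nps = preprocess_audio("song.ogg", nps=5.0)
#   mel.shape                  # torch.Size([N_MELS, T_mel])
#   HOP_LENGTH / SAMPLE_RATE   # seconds per mel frame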

# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)
        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)
        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )
        energies = out_dict["presence"].squeeze(0).cpu().numpy()  # (T_out, 3)
        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]
    return don_energy, ka_energy, drumroll_energy
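
# Usage sketch for run_inference (the model object and its checkpoint come from
# this repo's model module and are not defined in this file):
#
#   don, ka, drum = run_inference(model, mel, nps, torch.device("cpu"))
#   don.shape  # (T_out,) -- one energy value per model output frame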

# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames
    # Type mapping: 1:Don, 2:Ka, 5:Drumroll
    series_by_type = {1: don_energy, 2: ka_energy, 5: drumroll_energy}
    for i in range(1, T_out - 1):  # skip edges so each frame has both neighbors
        if i < last_onset_frame + min_distance_frames:
            continue
        energies_at_i = {typ: series[i] for typ, series in series_by_type.items()}
        # Try types in descending energy order so the strongest signal wins;
        # breaking after the first detection keeps at most one onset per frame.
        sorted_types_by_energy = sorted(
            energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
        )
        for onset_type in sorted_types_by_energy:
            series = series_by_type[onset_type]
            energy_val = series[i]
            # A detection is a local peak above threshold: strictly higher
            # than both neighboring frames.
            if (
                energy_val > threshold
                and energy_val > series[i - 1]
                and energy_val > series[i + 1]
            ):
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                break  # only one onset type per frame
    return results
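
# Worked example for decode_onsets (synthetic energies, illustrative only):
# a lone Don peak at frame 2 that clears the 0.5 threshold is reported once,
# at t = 2 * hop_sec; the ka/drumroll series never cross the threshold.
#
#   don = np.array([0.1, 0.4, 0.9, 0.3, 0.1])
#   ka = np.zeros(5)
#   drum = np.zeros(5)
#   decode_onsets(don, ka, drum, hop_sec=0.01)  # -> [(0.02, 1)]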

# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # model output time dimension
    # Time axis for the spectrogram: one mel frame every HOP_LENGTH samples
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    # Timing of model output frames: the CNN in model.py does not stride in
    # time, so the conformer `lengths` equal T_mel, T_out == T_mel, and each
    # output step of the `presence` head spans HOP_LENGTH / SAMPLE_RATE
    # seconds. TIME_SUB only affects label generation in preprocess.py, not
    # the feature rate, so a hop_sec of (HOP_LENGTH * TIME_SUB) / SAMPLE_RATE
    # would be inconsistent here; we assume the caller passes a hop_sec that
    # matches the model's actual output frame rate.
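    # Worked example (illustrative config values, not necessarily the repo's):
    # with SAMPLE_RATE = 22050 and HOP_LENGTH = 512, one output frame spans
    # 512 / 22050 ≈ 0.0232 s, so a 3-minute clip yields about
    # 180 * 22050 / 512 ≈ 7752 frames on both the mel and energy time axes.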
    time_axis_energies = np.arange(T_out) * hop_sec
    fig, ax1 = plt.subplots(figsize=(100, 10))
    # Plot mel spectrogram on ax1
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")
    # Second y-axis for the energy curves
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # assumes energies are roughly bounded in [0, 1]
    # Overlay onsets from decode_onsets (t_sec is already in seconds)
    labeled_types = set()
    # Group drumroll onsets into segments (mirrors the logic in write_tja)
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # merge across gaps of up to 5 frames
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))
    # Plot Don/Ka onsets as vertical lines
    color_map = {1: "darkred", 2: "darkblue"}
    label_map = {1: "Don Onset", 2: "Ka Onset"}
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # drumrolls are drawn as shaded segments below
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
    # Plot drumroll segments as shaded regions
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")
    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")
    fig.tight_layout()
    if out_path:
        # Save to file and release the figure
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        # No path given: return the live Figure object to the caller
        return fig

def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav", offset=0):
    # TJA note types: 0:none, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
    # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # assumes 4/4 time signature
    sec_per_measure = sec_per_beat * beats_per_measure
    # Step 1: Map onsets to (measure_idx, slot, typ)
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))
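    # Worked example with the defaults (bpm=160, quantize=96): sec_per_beat =
    # 60 / 160 = 0.375 s, sec_per_measure = 1.5 s, and each slot spans
    # 1.5 / 96 = 15.625 ms; an onset at t = 3.2 s falls in measure
    # int(3.2 // 1.5) = 2 at slot round(0.2 / 1.5 * 96) = round(12.8) = 13.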
    # Step 2: Build measure/slot grid
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
    # Step 3: Place Don/Ka, collect drumroll slots
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in [1, 2]:
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))
    # Step 4: Group drumroll slots into contiguous regions, to be marked
    # 5 (start) and 8 (end)
    drumroll_list = sorted(drumroll_slots)
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Slot distance between consecutive drumroll hits, allowing a
            # wrap into the immediately following measure
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1:
                slot_dist = (quantize - 1 - last_s) + s + 1
            else:
                slot_dist = None
            # Merge hits separated by up to 5 empty slots (slot_dist <= 6)
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)
    # Mark 5 (start) and 8 (end) for each group
    for region in grouped:
        if len(region) == 1:
            # Single-slot drumroll: place the end marker in the next slot
            # (or at the start of the next measure if at the boundary)
            m, s = region[0]
            measures[m][s] = 5
            if s < quantize - 1:
                measures[m][s + 1] = 8
            elif m < max_measure_idx:
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Intermediate slots stay 0 (the grid default)
    # Step 5: Generate TJA content
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC5, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append(f"OFFSET:{offset}")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")
    tja_string = "\n".join(tja_content)
    # Write to file as well if out_path is provided
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")
    return tja_string
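
# End-to-end sketch (hypothetical wiring; `load_model` is a placeholder for
# however this repo constructs and restores its model -- not defined here):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = load_model("checkpoint.pt").to(device)
#   mel, nps = preprocess_audio("song.ogg", nps=5.0)
#   don, ka, drum = run_inference(model, mel, nps, device)
#   hop_sec = HOP_LENGTH / SAMPLE_RATE  # model output frame rate (see plot_results notes)
#   onsets = decode_onsets(don, ka, drum, hop_sec, threshold=0.5)
#   plot_results(mel, don, ka, drum, onsets, hop_sec, out_path="inference.png")
#   write_tja(onsets, out_path="song.tja", bpm=160, audio="song.ogg", offset=0)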