# Spaces: Running on Zero
from accelerate.utils import set_seed | |
set_seed(1024) | |
import math | |
import torch | |
from torch.utils.data import DataLoader | |
from torch.utils.tensorboard import SummaryWriter | |
from tqdm import tqdm | |
from datasets import concatenate_datasets | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from .config import ( | |
BATCH_SIZE, | |
DEVICE, | |
EPOCHS, | |
LR, | |
GRAD_ACCUM_STEPS, | |
HOP_LENGTH, | |
SAMPLE_RATE, | |
) | |
from .model import TaikoConformer5 | |
from .dataset import ds | |
from .preprocess import preprocess, collate_fn | |
from .loss import TaikoEnergyLoss | |
from huggingface_hub import upload_folder | |
# --- Helper function to log energy plots ---
def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,  # Actual valid length of the sequence (before padding)
    hop_sec,
):
    """
    Draw predicted vs. true energies for a single sample and log the figure.

    The six energy arguments are sliced to ``valid_length`` and converted with
    ``.cpu().numpy()`` (predictions are additionally detached first), so they
    must be tensor-like objects supporting those calls. One stacked panel is
    drawn per channel (don, ka, drumroll) against a time axis of
    ``np.arange(valid_length) * hop_sec``; the figure is written through
    ``writer.add_figure`` under ``"{tag_prefix}/Energy_Comparison"`` at step
    ``epoch`` and then closed.
    """
    # Convert predictions first, then targets; only predictions are detached.
    preds = [
        tensor[:valid_length].detach().cpu().numpy()
        for tensor in (pred_don, pred_ka, pred_drumroll)
    ]
    trues = [
        tensor[:valid_length].cpu().numpy()
        for tensor in (true_don, true_ka, true_drumroll)
    ]
    seconds = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    # One row per panel:
    # (true label, pred label, y-axis label, true colour, pred colour)
    panels = (
        ("True Don", "Pred Don", "Don Energy", "blue", "lightblue"),
        ("True Ka", "Pred Ka", "Ka Energy", "red", "lightcoral"),
        (
            "True Drumroll",
            "Pred Drumroll",
            "Drumroll Energy",
            "green",
            "lightgreen",
        ),
    )
    bottom_row = len(panels) - 1
    for row, (true_label, pred_label, y_label, true_color, pred_color) in enumerate(
        panels
    ):
        ax = axs[row]
        ax.plot(seconds, trues[row], label=true_label, color=true_color, linestyle="--")
        ax.plot(seconds, preds[row], label=pred_label, color=pred_color, alpha=0.8)
        ax.set_ylabel(y_label)
        if row == bottom_row:
            # The x axis is shared, so only the bottom panel is labelled.
            ax.set_xlabel("Time (s)")
        ax.legend()
        ax.grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave room at the top for the suptitle
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)
def _run_batch(model, criterion, batch):
    """Move one collated batch to ``DEVICE``, run the model and compute the loss.

    Args:
        model: Callable invoked as ``model(mel, lengths, nps)`` that returns a
            dict containing a ``"presence"`` entry.
        criterion: Callable invoked as ``criterion(output_dict, targets)``.
        batch: Mapping with ``"mel"``, ``"don_labels"``, ``"ka_labels"``,
            ``"drumroll_labels"``, ``"lengths"`` and ``"nps"`` tensors.

    Returns:
        ``(loss, pred_energies, targets)`` where ``pred_energies`` is
        ``output_dict["presence"]`` and ``targets`` is the dict of
        device-resident labels and lengths that was handed to the criterion.
    """
    mel = batch["mel"].to(DEVICE)
    targets = {
        "don_labels": batch["don_labels"].to(DEVICE),
        "ka_labels": batch["ka_labels"].to(DEVICE),
        "drumroll_labels": batch["drumroll_labels"].to(DEVICE),
        "lengths": batch["lengths"].to(DEVICE),  # used for masking inside the loss
    }
    nps = batch["nps"].to(DEVICE)
    output_dict = model(mel, targets["lengths"], nps)
    loss = criterion(output_dict, targets)
    return loss, output_dict["presence"], targets


def _log_first_sample(writer, tag_prefix, epoch, pred_energies, targets, hop_sec):
    """Plot sample 0 of a batch (channels 0/1/2 = don/ka/drumroll) to TensorBoard."""
    log_energy_plots_to_tensorboard(
        writer,
        tag_prefix,
        epoch,
        pred_energies[0, :, 0],
        pred_energies[0, :, 1],
        pred_energies[0, :, 2],
        targets["don_labels"][0, :],
        targets["ka_labels"][0, :],
        targets["drumroll_labels"][0, :],
        targets["lengths"][0].item(),  # valid (unpadded) length of that sample
        hop_sec,
    )


def main():
    """Train ``TaikoConformer5`` with early stopping and upload the result.

    Steps:
      1. Preprocess the dataset once per difficulty ("oni", "hard", "normal"),
         concatenate the results and split 90/10 into train/validation.
      2. Train with AdamW + OneCycleLR and gradient accumulation, logging
         losses, learning rate and example energy plots to TensorBoard.
      3. Save the weights with the lowest average validation loss to
         ``best_model.pt`` and stop after ``patience`` epochs without
         improvement.
      4. Reload those best weights and push the model plus the ``runs``
         folder to the Hugging Face Hub (upload errors are reported, not
         raised).

    Note: the module-level ``ds`` is rebound to the concatenated dataset.
    """
    global ds
    # Seconds per output frame; assumes the model output time dimension
    # matches the mel spectrogram time dimension.
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    best_val_loss = float("inf")
    best_model_path = "best_model.pt"
    best_model_saved = False  # becomes True once this run has written a checkpoint
    patience = 10
    pat_count = 0

    # Same preprocessing for every difficulty; only the kwarg differs.
    per_difficulty = [
        ds.map(
            preprocess,
            remove_columns=ds.column_names,
            fn_kwargs={"difficulty": difficulty},
            writer_batch_size=10,
        )
        for difficulty in ("oni", "hard", "normal")
    ]
    ds = concatenate_datasets(per_difficulty)
    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)

    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
    )
    num_train_batches = len(train_loader)

    model = TaikoConformer5().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)

    # The scheduler advances once per optimizer step, not once per batch, so
    # its horizon is counted in accumulated steps.
    num_optimizer_steps_per_epoch = math.ceil(num_train_batches / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )

    writer = SummaryWriter()
    try:
        for epoch in range(1, EPOCHS + 1):
            # ---- Training ----
            model.train()
            total_epoch_loss = 0.0
            optimizer.zero_grad()
            for idx, batch in enumerate(
                tqdm(train_loader, desc=f"Train Epoch {epoch}")
            ):
                loss, pred_energies, targets = _run_batch(model, criterion, batch)
                (loss / GRAD_ACCUM_STEPS).backward()

                # Step every GRAD_ACCUM_STEPS batches, and once more on the
                # final batch so a trailing partial group is not dropped.
                is_last_batch = (idx + 1) == num_train_batches
                if (idx + 1) % GRAD_ACCUM_STEPS == 0 or is_last_batch:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                total_epoch_loss += loss.item()

                # Plot the first sample of the first batch of each epoch.
                if idx == 0:
                    _log_first_sample(
                        writer,
                        "Train/Sample_0",
                        epoch,
                        pred_energies,
                        targets,
                        output_frame_hop_sec,
                    )

            avg_train_loss = total_epoch_loss / num_train_batches
            writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)

            # ---- Validation ----
            model.eval()
            total_val_loss = 0.0
            with torch.no_grad():
                for val_idx, batch in enumerate(
                    tqdm(val_loader, desc=f"Val Epoch {epoch}")
                ):
                    val_loss, pred_energies, targets = _run_batch(
                        model, criterion, batch
                    )
                    total_val_loss += val_loss.item()

                    if val_idx == 0:
                        _log_first_sample(
                            writer,
                            "Eval/Sample_0",
                            epoch,
                            pred_energies,
                            targets,
                            output_frame_hop_sec,
                        )

            avg_val_loss = total_val_loss / len(val_loader)
            writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

            current_lr = optimizer.param_groups[0]["lr"]
            writer.add_scalar("LR/learning_rate", current_lr, epoch)

            # Ground-truth NPS of the last validation batch, for reference.
            if "nps" in batch:
                writer.add_scalar(
                    "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
                )

            print(
                f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
            )

            # ---- Checkpointing / early stopping ----
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                pat_count = 0
                torch.save(model.state_dict(), best_model_path)
                best_model_saved = True
                print(f"Saved new best model to {best_model_path} at epoch {epoch}")
            else:
                pat_count += 1
                if pat_count >= patience:
                    print("Early stopping!")
                    break
    finally:
        # Close the event file even if training is interrupted by an error.
        writer.close()

    # The in-memory model holds the weights of the last epoch that ran, which
    # after early stopping is `patience` epochs past the best one. Restore the
    # best checkpoint so that the uploaded model is the one that was selected.
    if best_model_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))

    model_id = "JacobLinCool/taiko-conformer-5"
    try:
        model.push_to_hub(model_id, commit_message="Upload trained model")
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo=".",
            commit_message="Upload training logs",
            ignore_patterns=["*.txt", "*.json", "*.csv"],
        )
        print(f"Model and logs uploaded to {model_id}")
    except Exception as e:
        # Best effort: training results are already on disk, so an upload
        # failure is reported instead of aborting the script.
        print(f"Error uploading to Hugging Face Hub: {e}")
        print("Make sure you have the correct permissions and try again.")
# Script entry point: run training only when this file is executed, not when
# it is imported. The module uses package-relative imports
# (``from .config import ...``), so it has to be started as part of its
# package (``python -m <package>.<module>``) rather than by bare file path.
if __name__ == "__main__":
    main()