from accelerate.utils import set_seed

# Seed all RNGs before the dataset and model modules are imported.
set_seed(1024)

import math

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import upload_folder

from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    SAMPLE_RATE,
)
from .model import TaikoConformer5
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoEnergyLoss


# --- Helper function to log energy plots ---
def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,  # Actual valid length of the sequence (before padding)
    hop_sec,
):
    """
    Logs a plot of predicted vs. true energies for one sample to TensorBoard.
    Inputs should be 1D tensors for a single sample; only the first
    `valid_length` frames (the unpadded portion) are plotted.
    """
    # Ensure data is on CPU and converted to numpy, and select only the valid part
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()

    time_axis = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
    axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
    axs[0].set_ylabel("Don Energy")
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
    axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
    axs[1].set_ylabel("Ka Energy")
    axs[1].legend()
    axs[1].grid(True)

    axs[2].plot(
        time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
    )
    axs[2].plot(
        time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
    )
    axs[2].set_ylabel("Drumroll Energy")
    axs[2].set_xlabel("Time (s)")
    axs[2].legend()
    axs[2].grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust layout to make space for suptitle
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)
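
# Both TaikoEnergyLoss and the plotting helper above rely on `lengths` to
# ignore padded frames. For reference, the kind of length mask this implies
# looks like the sketch below (illustrative only: the real masking lives in
# .loss, and `per_frame_loss` / `T` are hypothetical names):
#
#     mask = torch.arange(T, device=lengths.device)[None, :] < lengths[:, None]
#     masked_mean = (per_frame_loss * mask).sum() / mask.sum().clamp(min=1)
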
def main():
    global ds

    # Seconds per model output frame. This assumes the model's output time
    # dimension matches the mel spectrogram's time dimension.
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE

    best_val_loss = float("inf")
    patience = 10  # Early-stopping patience, in epochs
    pat_count = 0

    # Preprocess each difficulty separately, then pool into a single dataset.
    ds_oni = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "oni"},
        writer_batch_size=10,
    )
    ds_hard = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "hard"},
        writer_batch_size=10,
    )
    ds_normal = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "normal"},
        writer_batch_size=10,
    )
    ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])

    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
    )

    model = TaikoConformer5().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)

    # Size the scheduler in optimizer steps, not batches: gradient accumulation
    # yields one optimizer step per GRAD_ACCUM_STEPS batches.
    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )

    writer = SummaryWriter()

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()

        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
            # Energy-based labels for each event type
            don_labels = batch["don_labels"].to(DEVICE)
            ka_labels = batch["ka_labels"].to(DEVICE)
            drumroll_labels = batch["drumroll_labels"].to(DEVICE)
            # Valid (unpadded) lengths of the Conformer output sequences
            lengths = batch["lengths"].to(DEVICE)
            nps = batch["nps"].to(DEVICE)

            output_dict = model(mel, lengths, nps)
            # output_dict["presence"] holds (B, T_out, 3) energies: don, ka, drumroll
            pred_energies_batch = output_dict["presence"]

            loss_input_batch = {
                "don_labels": don_labels,
                "ka_labels": ka_labels,
                "drumroll_labels": drumroll_labels,
                "lengths": lengths,  # Used for masking inside the loss function
            }
            loss = criterion(output_dict, loss_input_batch)

            # Scale so the accumulated gradient matches a single large batch.
            (loss / GRAD_ACCUM_STEPS).backward()

            if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            total_epoch_loss += loss.item()

            # Log a plot for the first sample of the first batch of each training epoch
            if idx == 0:
                first_sample_pred_don = pred_energies_batch[0, :, 0]
                first_sample_pred_ka = pred_energies_batch[0, :, 1]
                first_sample_pred_drumroll = pred_energies_batch[0, :, 2]
                first_sample_true_don = don_labels[0, :]
                first_sample_true_ka = ka_labels[0, :]
                first_sample_true_drumroll = drumroll_labels[0, :]
                first_sample_length = lengths[0].item()  # Valid length of this sample
                log_energy_plots_to_tensorboard(
                    writer,
                    "Train/Sample_0",
                    epoch,
                    first_sample_pred_don,
                    first_sample_pred_ka,
                    first_sample_pred_drumroll,
                    first_sample_true_don,
                    first_sample_true_ka,
                    first_sample_true_drumroll,
                    first_sample_length,
                    output_frame_hop_sec,
                )

        avg_train_loss = total_epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
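        # Note: total_epoch_loss accumulates the unscaled per-batch loss, so the
        # average above is per batch and unaffected by the GRAD_ACCUM_STEPS scaling.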

        # Validation
        model.eval()
        total_val_loss = 0.0

        with torch.no_grad():
            for val_idx, batch in enumerate(
                tqdm(val_loader, desc=f"Val Epoch {epoch}")
            ):
                mel = batch["mel"].to(DEVICE)
                don_labels = batch["don_labels"].to(DEVICE)
                ka_labels = batch["ka_labels"].to(DEVICE)
                drumroll_labels = batch["drumroll_labels"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)  # Ground-truth NPS from the batch

                output_dict = model(mel, lengths, nps)
                pred_energies_val_batch = output_dict["presence"]  # (B, T_out, 3)

                val_loss_input_batch = {
                    "don_labels": don_labels,
                    "ka_labels": ka_labels,
                    "drumroll_labels": drumroll_labels,
                    "lengths": lengths,
                }
                val_loss = criterion(output_dict, val_loss_input_batch)
                total_val_loss += val_loss.item()

                # Log a plot for the first sample of the first batch of each validation epoch
                if val_idx == 0:
                    first_val_sample_pred_don = pred_energies_val_batch[0, :, 0]
                    first_val_sample_pred_ka = pred_energies_val_batch[0, :, 1]
                    first_val_sample_pred_drumroll = pred_energies_val_batch[0, :, 2]
                    first_val_sample_true_don = don_labels[0, :]
                    first_val_sample_true_ka = ka_labels[0, :]
                    first_val_sample_true_drumroll = drumroll_labels[0, :]
                    first_val_sample_length = lengths[0].item()
                    log_energy_plots_to_tensorboard(
                        writer,
                        "Eval/Sample_0",
                        epoch,
                        first_val_sample_pred_don,
                        first_val_sample_pred_ka,
                        first_val_sample_pred_drumroll,
                        first_val_sample_true_don,
                        first_val_sample_true_ka,
                        first_val_sample_true_drumroll,
                        first_val_sample_length,
                        output_frame_hop_sec,
                    )

                # Per-batch ground-truth NPS could be logged here if needed:
                # writer.add_scalar(
                #     "NPS/GT_Val_Batch_Avg", nps.mean().item(),
                #     epoch * len(val_loader) + val_idx,
                # )

        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        # Log the learning rate
        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        # Log ground-truth NPS from the last validation batch
        if "nps" in batch:
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )

        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} "
            f"| Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break

    writer.close()

    model_id = "JacobLinCool/taiko-conformer-5"
    try:
        model.push_to_hub(model_id, commit_message="Upload trained model")
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo=".",
            commit_message="Upload training logs",
            ignore_patterns=["*.txt", "*.json", "*.csv"],
        )
        print(f"Model and logs uploaded to {model_id}")
    except Exception as e:
        print(f"Error uploading to Hugging Face Hub: {e}")
        print("Make sure you have the correct permissions and try again.")


if __name__ == "__main__":
    main()
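
# To reuse the checkpoint saved above for inference, something along these
# lines should work (a sketch; the CPU map_location is an assumption, adjust
# to your deployment device):
#
#     model = TaikoConformer5()
#     model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
#     model.eval()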