from accelerate.utils import set_seed

# Seed everything (python, numpy, torch, cuda) before any other work so that
# dataset splits and weight init are reproducible across runs.
set_seed(1024)

import math

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np

from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    NPS_PENALTY_WEIGHT_ALPHA,
    NPS_PENALTY_WEIGHT_BETA,
    SAMPLE_RATE,
)
from .model import TaikoConformer7
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoLoss
from huggingface_hub import upload_folder


def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,
    hop_sec,
):
    """
    Logs a plot of predicted vs. true energies for one sample to TensorBoard.

    Each of the six energy arguments is a 1D tensor for a single sample; only
    the first ``valid_length`` frames are plotted (the rest is padding).
    ``hop_sec`` converts frame indices to seconds for the x-axis.
    """
    # Predictions may carry grad; detach before moving to numpy.
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()

    time_axis = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
    axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
    axs[0].set_ylabel("Don Energy")
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
    axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
    axs[1].set_ylabel("Ka Energy")
    axs[1].legend()
    axs[1].grid(True)

    axs[2].plot(
        time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
    )
    axs[2].plot(
        time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
    )
    axs[2].set_ylabel("Drumroll Energy")
    axs[2].set_xlabel("Time (s)")
    axs[2].legend()
    axs[2].grid(True)

    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    # Close explicitly so repeated logging does not leak figures.
    plt.close(fig)


def main():
    """Train TaikoConformer7 with grad accumulation, OneCycleLR, early
    stopping on validation loss, and TensorBoard logging; on completion the
    best model and the TensorBoard logs are pushed to the HuggingFace Hub."""
    global ds

    # Seconds per output frame, used for the x-axis of logged energy plots.
    # NOTE(review): assumes the model emits one frame per mel hop — confirm
    # against TaikoConformer7 if it subsamples in time.
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE

    best_val_loss = float("inf")
    patience = 10
    pat_count = 0

    # Build one dataset per difficulty from the same raw data, then pool them
    # so the model trains on all three chart difficulties.
    ds_oni = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "oni"},
        writer_batch_size=10,
    )
    ds_hard = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "hard"},
        writer_batch_size=10,
    )
    ds_normal = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "normal"},
        writer_batch_size=10,
    )
    ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])

    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )

    model = TaikoConformer7().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = TaikoLoss(
        reduction="mean",
        nps_penalty_weight_alpha=NPS_PENALTY_WEIGHT_ALPHA,
        nps_penalty_weight_beta=NPS_PENALTY_WEIGHT_BETA,
    ).to(DEVICE)

    # The scheduler steps once per *optimizer* step, not per batch, so size
    # its schedule by the number of accumulated steps.
    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )

    writer = SummaryWriter()

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()
        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
            lengths = batch["lengths"].to(DEVICE)
            nps = batch["nps"].to(DEVICE)
            difficulty = batch["difficulty"].to(DEVICE)
            level = batch["level"].to(DEVICE)

            outputs = model(mel, lengths, nps, difficulty, level)
            loss = criterion(outputs, batch)
            # Accumulate the *unscaled* loss for the epoch average, then
            # scale down for gradient accumulation.
            total_epoch_loss += loss.item()
            loss = loss / GRAD_ACCUM_STEPS
            loss.backward()

            # Step the optimizer every GRAD_ACCUM_STEPS batches, and also on
            # the final (possibly partial) accumulation window of the epoch.
            if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Global step: epoch is 1-based, so subtract 1 to make the very
            # first batch land on step 0 (was epoch * len(train_loader) + idx,
            # which left a gap at the start of the step axis).
            global_step = (epoch - 1) * len(train_loader) + idx
            writer.add_scalar(
                "Loss/Train_Step",
                loss.item() * GRAD_ACCUM_STEPS,
                global_step,
            )
            writer.add_scalar("LR", scheduler.get_last_lr()[0], global_step)

            # Log qualitative energy plots for the first few batches only.
            if idx < 3:
                if mel.size(0) > 0:
                    pred_don = outputs["presence"][0, :, 0]
                    pred_ka = outputs["presence"][0, :, 1]
                    pred_drumroll = outputs["presence"][0, :, 2]
                    true_don = batch["don_labels"][0]
                    true_ka = batch["ka_labels"][0]
                    true_drumroll = batch["drumroll_labels"][0]
                    valid_length = batch["lengths"][0].item()
                    log_energy_plots_to_tensorboard(
                        writer,
                        f"Train_Sample_Batch_{idx}_Sample_0",
                        epoch,
                        pred_don,
                        pred_ka,
                        pred_drumroll,
                        true_don,
                        true_ka,
                        true_drumroll,
                        valid_length,
                        output_frame_hop_sec,
                    )

        avg_train_loss = total_epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)

        # ---- Validation ----
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for idx, batch in enumerate(tqdm(val_loader, desc=f"Val Epoch {epoch}")):
                mel = batch["mel"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)
                difficulty = batch["difficulty"].to(DEVICE)
                level = batch["level"].to(DEVICE)

                outputs = model(mel, lengths, nps, difficulty, level)
                loss = criterion(outputs, batch)
                total_val_loss += loss.item()

                if idx < 3:
                    if mel.size(0) > 0:
                        pred_don = outputs["presence"][0, :, 0]
                        pred_ka = outputs["presence"][0, :, 1]
                        pred_drumroll = outputs["presence"][0, :, 2]
                        true_don = batch["don_labels"][0]
                        true_ka = batch["ka_labels"][0]
                        true_drumroll = batch["drumroll_labels"][0]
                        valid_length = batch["lengths"][0].item()
                        log_energy_plots_to_tensorboard(
                            writer,
                            f"Val_Sample_Batch_{idx}_Sample_0",
                            epoch,
                            pred_don,
                            pred_ka,
                            pred_drumroll,
                            true_don,
                            true_ka,
                            true_drumroll,
                            valid_length,
                            output_frame_hop_sec,
                        )

        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        # "batch" here is the last validation batch — this logs its mean
        # ground-truth notes-per-second as a sanity signal.
        if "nps" in batch:
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )

        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )

        # Early stopping on validation loss with checkpointing of the best model.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break

    writer.close()

    # Best-effort upload of the final model and TensorBoard runs; failures are
    # reported but do not crash the (already finished) training run.
    model_id = "JacobLinCool/taiko-conformer-7"
    try:
        model.push_to_hub(
            model_id, commit_message=f"Epoch {epoch}, Val Loss: {avg_val_loss:.4f}"
        )
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo="runs",
            commit_message="Upload TensorBoard logs",
        )
    except Exception as e:
        print(f"Error uploading model or logs: {e}")
        print("Make sure you have the correct permissions and try again.")


if __name__ == "__main__":
    main()