# tc7/train.py: training pipeline for TaikoConformer7
# Seed Python, NumPy, and PyTorch RNGs before the remaining imports so that any
# randomness triggered while the modules below are imported stays reproducible.
from accelerate.utils import set_seed

set_seed(1024)

import math
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
from .config import (
BATCH_SIZE,
DEVICE,
EPOCHS,
LR,
GRAD_ACCUM_STEPS,
HOP_LENGTH,
NPS_PENALTY_WEIGHT_ALPHA,
NPS_PENALTY_WEIGHT_BETA,
SAMPLE_RATE,
)
from .model import TaikoConformer7
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoLoss
from huggingface_hub import upload_folder


def log_energy_plots_to_tensorboard(
writer,
tag_prefix,
epoch,
pred_don,
pred_ka,
pred_drumroll,
true_don,
true_ka,
true_drumroll,
valid_length,
hop_sec,
):
"""
Logs a plot of predicted vs. true energies for one sample to TensorBoard.
Energies should be 1D numpy arrays for the single sample, up to valid_length.
"""
pred_don = pred_don[:valid_length].detach().cpu().numpy()
pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
true_don = true_don[:valid_length].cpu().numpy()
true_ka = true_ka[:valid_length].cpu().numpy()
true_drumroll = true_drumroll[:valid_length].cpu().numpy()
time_axis = np.arange(valid_length) * hop_sec
fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)
axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
axs[0].set_ylabel("Don Energy")
axs[0].legend()
axs[0].grid(True)
axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
axs[1].set_ylabel("Ka Energy")
axs[1].legend()
axs[1].grid(True)
axs[2].plot(
time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
)
axs[2].plot(
time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
)
axs[2].set_ylabel("Drumroll Energy")
axs[2].set_xlabel("Time (s)")
axs[2].legend()
axs[2].grid(True)
plt.tight_layout(rect=[0, 0, 1, 0.96])
writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)


def main():
    # Rebind the module-level `ds` imported from .dataset with its preprocessed version.
    global ds
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE  # seconds per output frame
best_val_loss = float("inf")
patience = 10
pat_count = 0
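    # Preprocess each chart difficulty separately so one song can contribute up to three
    # training examples (oni / hard / normal), then concatenate them into a single dataset.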
ds_oni = ds.map(
preprocess,
remove_columns=ds.column_names,
fn_kwargs={"difficulty": "oni"},
writer_batch_size=10,
)
ds_hard = ds.map(
preprocess,
remove_columns=ds.column_names,
fn_kwargs={"difficulty": "hard"},
writer_batch_size=10,
)
ds_normal = ds.map(
preprocess,
remove_columns=ds.column_names,
fn_kwargs={"difficulty": "normal"},
writer_batch_size=10,
)
ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])
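    # 90/10 train/validation split with a fixed seed. Note that the split happens after
    # concatenation, so different difficulty charts of the same song may land in both splits.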
ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
train_loader = DataLoader(
ds_train_test["train"],
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=collate_fn,
num_workers=8,
persistent_workers=True,
prefetch_factor=4,
)
val_loader = DataLoader(
ds_train_test["test"],
batch_size=BATCH_SIZE,
shuffle=False,
collate_fn=collate_fn,
num_workers=8,
persistent_workers=True,
prefetch_factor=4,
)
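    # Both loaders keep their 8 worker processes alive across epochs (persistent_workers)
    # and prefetch 4 batches per worker to hide preprocessing latency.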
model = TaikoConformer7().to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = TaikoLoss(
reduction="mean",
nps_penalty_weight_alpha=NPS_PENALTY_WEIGHT_ALPHA,
nps_penalty_weight_beta=NPS_PENALTY_WEIGHT_BETA,
).to(DEVICE)
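    # TaikoLoss is assumed to read the don/ka/drumroll label tensors (and valid lengths)
    # straight from the raw batch dict; the alpha/beta weights presumably scale its
    # notes-per-second penalty term.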
num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=LR, total_steps=total_optimizer_steps
)
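    # OneCycleLR is stepped once per optimizer step (inside the gradient-accumulation
    # branch below), so total_steps counts optimizer steps, not mini-batches.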
writer = SummaryWriter()
for epoch in range(1, EPOCHS + 1):
model.train()
total_epoch_loss = 0.0
optimizer.zero_grad()
for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
mel = batch["mel"].to(DEVICE)
lengths = batch["lengths"].to(DEVICE)
nps = batch["nps"].to(DEVICE)
difficulty = batch["difficulty"].to(DEVICE)
level = batch["level"].to(DEVICE)
outputs = model(mel, lengths, nps, difficulty, level)
loss = criterion(outputs, batch)
total_epoch_loss += loss.item()
loss = loss / GRAD_ACCUM_STEPS
loss.backward()
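            # Step the optimizer every GRAD_ACCUM_STEPS mini-batches (or on the final,
            # possibly shorter, group), clipping the gradient norm to 1.0 first.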
if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
writer.add_scalar(
"Loss/Train_Step",
loss.item() * GRAD_ACCUM_STEPS,
epoch * len(train_loader) + idx,
)
writer.add_scalar(
"LR", scheduler.get_last_lr()[0], epoch * len(train_loader) + idx
)
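            # For the first three batches of the epoch, plot predicted vs. true energies
            # for the first sample in the batch.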
if idx < 3:
if mel.size(0) > 0:
pred_don = outputs["presence"][0, :, 0]
pred_ka = outputs["presence"][0, :, 1]
pred_drumroll = outputs["presence"][0, :, 2]
true_don = batch["don_labels"][0]
true_ka = batch["ka_labels"][0]
true_drumroll = batch["drumroll_labels"][0]
valid_length = batch["lengths"][0].item()
log_energy_plots_to_tensorboard(
writer,
f"Train_Sample_Batch_{idx}_Sample_0",
epoch,
pred_don,
pred_ka,
pred_drumroll,
true_don,
true_ka,
true_drumroll,
valid_length,
output_frame_hop_sec,
)
avg_train_loss = total_epoch_loss / len(train_loader)
writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
model.eval()
total_val_loss = 0.0
with torch.no_grad():
for idx, batch in enumerate(tqdm(val_loader, desc=f"Val Epoch {epoch}")):
mel = batch["mel"].to(DEVICE)
lengths = batch["lengths"].to(DEVICE)
nps = batch["nps"].to(DEVICE)
difficulty = batch["difficulty"].to(DEVICE)
level = batch["level"].to(DEVICE)
outputs = model(mel, lengths, nps, difficulty, level)
loss = criterion(outputs, batch)
total_val_loss += loss.item()
if idx < 3:
if mel.size(0) > 0:
pred_don = outputs["presence"][0, :, 0]
pred_ka = outputs["presence"][0, :, 1]
pred_drumroll = outputs["presence"][0, :, 2]
true_don = batch["don_labels"][0]
true_ka = batch["ka_labels"][0]
true_drumroll = batch["drumroll_labels"][0]
valid_length = batch["lengths"][0].item()
log_energy_plots_to_tensorboard(
writer,
f"Val_Sample_Batch_{idx}_Sample_0",
epoch,
pred_don,
pred_ka,
pred_drumroll,
true_don,
true_ka,
true_drumroll,
valid_length,
output_frame_hop_sec,
)
avg_val_loss = total_val_loss / len(val_loader)
writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)
current_lr = optimizer.param_groups[0]["lr"]
writer.add_scalar("LR/learning_rate", current_lr, epoch)
if "nps" in batch:
writer.add_scalar(
"NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
)
print(
f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
pat_count = 0
torch.save(model.state_dict(), "best_model.pt")
print(f"Saved new best model to best_model.pt at epoch {epoch}")
else:
pat_count += 1
if pat_count >= patience:
print("Early stopping!")
break
writer.close()
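    # NOTE: push_to_hub uploads the weights currently in `model`, i.e. those from the last
    # epoch run; load best_model.pt first if the best validation checkpoint should be published.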
model_id = "JacobLinCool/taiko-conformer-7"
try:
model.push_to_hub(
model_id, commit_message=f"Epoch {epoch}, Val Loss: {avg_val_loss:.4f}"
)
upload_folder(
repo_id=model_id,
folder_path="runs",
path_in_repo="runs",
commit_message="Upload TensorBoard logs",
)
except Exception as e:
print(f"Error uploading model or logs: {e}")
print("Make sure you have the correct permissions and try again.")
if __name__ == "__main__":
main()