# tc5/train.py
# Seed all RNGs before importing project modules, so any randomness triggered
# at import time (e.g. in .dataset) is reproducible.
from accelerate.utils import set_seed

set_seed(1024)
import math
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
from .config import (
BATCH_SIZE,
DEVICE,
EPOCHS,
LR,
GRAD_ACCUM_STEPS,
HOP_LENGTH,
SAMPLE_RATE,
)
from .model import TaikoConformer5
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoEnergyLoss
from huggingface_hub import upload_folder
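
# Training pipeline: preprocess the dataset at three difficulties (oni, hard,
# normal), train TaikoConformer5 on frame-level don/ka/drumroll energies with
# gradient accumulation and a OneCycleLR schedule, log losses and sample plots
# to TensorBoard, early-stop on validation loss, and finally upload the model
# and logs to the Hugging Face Hub.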
# --- Helper function to log energy plots ---
def log_energy_plots_to_tensorboard(
writer,
tag_prefix,
epoch,
pred_don,
pred_ka,
pred_drumroll,
true_don,
true_ka,
true_drumroll,
valid_length, # Actual valid length of the sequence (before padding)
hop_sec,
):
"""
Logs a plot of predicted vs. true energies for one sample to TensorBoard.
Energies should be 1D numpy arrays for the single sample, up to valid_length.
"""
# Ensure data is on CPU and converted to numpy, and select only the valid part
pred_don = pred_don[:valid_length].detach().cpu().numpy()
pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
true_don = true_don[:valid_length].cpu().numpy()
true_ka = true_ka[:valid_length].cpu().numpy()
true_drumroll = true_drumroll[:valid_length].cpu().numpy()
time_axis = np.arange(valid_length) * hop_sec
fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)
axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
axs[0].set_ylabel("Don Energy")
axs[0].legend()
axs[0].grid(True)
axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
axs[1].set_ylabel("Ka Energy")
axs[1].legend()
axs[1].grid(True)
axs[2].plot(
time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
)
axs[2].plot(
time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
)
axs[2].set_ylabel("Drumroll Energy")
axs[2].set_xlabel("Time (s)")
axs[2].legend()
axs[2].grid(True)
plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make space for suptitle
writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
plt.close(fig)
def main():
global ds
# Calculate hop seconds for model output frames
# This assumes the model output time dimension corresponds to the mel spectrogram time dimension
output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
best_val_loss = float("inf")
    patience = 10  # epochs without validation-loss improvement before early stopping
pat_count = 0
ds_oni = ds.map(
preprocess,
remove_columns=ds.column_names,
fn_kwargs={"difficulty": "oni"},
writer_batch_size=10,
)
ds_hard = ds.map(
preprocess,
remove_columns=ds.column_names,
fn_kwargs={"difficulty": "hard"},
writer_batch_size=10,
)
ds_normal = ds.map(
preprocess,
remove_columns=ds.column_names,
fn_kwargs={"difficulty": "normal"},
writer_batch_size=10,
)
ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])
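    # Each source example now appears once per difficulty in the pooled
    # dataset; the train/test split below operates on this combined set.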
ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
train_loader = DataLoader(
ds_train_test["train"],
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=collate_fn,
num_workers=2,
)
val_loader = DataLoader(
ds_train_test["test"],
batch_size=BATCH_SIZE,
shuffle=False,
collate_fn=collate_fn,
num_workers=2,
)
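    # Each collated batch is expected to carry "mel", "don_labels", "ka_labels",
    # "drumroll_labels", "lengths", and "nps" (consumed by the loops below).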
model = TaikoConformer5().to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)
# Adjust scheduler steps for gradient accumulation
num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=LR, total_steps=total_optimizer_steps
)
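    # Note: scheduler.step() is called once per optimizer update inside the
    # training loop, matching the total_optimizer_steps computed above.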
writer = SummaryWriter()
for epoch in range(1, EPOCHS + 1):
model.train()
total_epoch_loss = 0.0
optimizer.zero_grad()
for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
mel = batch["mel"].to(DEVICE)
# Unpack new energy-based labels
don_labels = batch["don_labels"].to(DEVICE)
ka_labels = batch["ka_labels"].to(DEVICE)
drumroll_labels = batch["drumroll_labels"].to(DEVICE)
            lengths = batch["lengths"].to(DEVICE)  # valid frame counts for the Conformer output
nps = batch["nps"].to(DEVICE)
output_dict = model(mel, lengths, nps)
            # output_dict["presence"] has shape (B, T_out, 3): don, ka, drumroll energies
            pred_energies_batch = output_dict["presence"]
loss_input_batch = {
"don_labels": don_labels,
"ka_labels": ka_labels,
"drumroll_labels": drumroll_labels,
"lengths": lengths, # Pass lengths for masking within the loss function
}
loss = criterion(output_dict, loss_input_batch)
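            # Scale the loss so gradients accumulated over GRAD_ACCUM_STEPS
            # micro-batches average to the gradient of one large effective batch.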
(loss / GRAD_ACCUM_STEPS).backward()
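            # Update weights every GRAD_ACCUM_STEPS micro-batches; the final,
            # possibly smaller, group at the end of the epoch is also flushed.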
if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
total_epoch_loss += loss.item()
# Log plot for the first sample of the first batch in each training epoch
if idx == 0:
first_sample_pred_don = pred_energies_batch[0, :, 0]
first_sample_pred_ka = pred_energies_batch[0, :, 1]
first_sample_pred_drumroll = pred_energies_batch[0, :, 2]
first_sample_true_don = don_labels[0, :]
first_sample_true_ka = ka_labels[0, :]
first_sample_true_drumroll = drumroll_labels[0, :]
                first_sample_length = lengths[0].item()  # valid (unpadded) length of the first sample
log_energy_plots_to_tensorboard(
writer,
"Train/Sample_0",
epoch,
first_sample_pred_don,
first_sample_pred_ka,
first_sample_pred_drumroll,
first_sample_true_don,
first_sample_true_ka,
first_sample_true_drumroll,
first_sample_length,
output_frame_hop_sec,
)
avg_train_loss = total_epoch_loss / len(train_loader)
writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
# Validation
model.eval()
total_val_loss = 0.0
with torch.no_grad():
        for val_idx, batch in enumerate(tqdm(val_loader, desc=f"Val Epoch {epoch}")):
mel = batch["mel"].to(DEVICE)
don_labels = batch["don_labels"].to(DEVICE)
ka_labels = batch["ka_labels"].to(DEVICE)
drumroll_labels = batch["drumroll_labels"].to(DEVICE)
lengths = batch["lengths"].to(DEVICE)
nps = batch["nps"].to(DEVICE) # Ground truth NPS from batch
output_dict = model(mel, lengths, nps)
pred_energies_val_batch = output_dict["presence"] # (B, T_out, 3)
val_loss_input_batch = {
"don_labels": don_labels,
"ka_labels": ka_labels,
"drumroll_labels": drumroll_labels,
"lengths": lengths,
}
val_loss = criterion(output_dict, val_loss_input_batch)
total_val_loss += val_loss.item()
# Log plot for the first sample of the first batch in each validation epoch
if val_idx == 0:
first_val_sample_pred_don = pred_energies_val_batch[0, :, 0]
first_val_sample_pred_ka = pred_energies_val_batch[0, :, 1]
first_val_sample_pred_drumroll = pred_energies_val_batch[0, :, 2]
first_val_sample_true_don = don_labels[0, :]
first_val_sample_true_ka = ka_labels[0, :]
first_val_sample_true_drumroll = drumroll_labels[0, :]
first_val_sample_length = lengths[0].item()
log_energy_plots_to_tensorboard(
writer,
"Eval/Sample_0",
epoch,
first_val_sample_pred_don,
first_val_sample_pred_ka,
first_val_sample_pred_drumroll,
first_val_sample_true_don,
first_val_sample_true_ka,
first_val_sample_true_drumroll,
first_val_sample_length,
output_frame_hop_sec,
)
                # Optionally log per-batch ground-truth NPS during validation:
                # writer.add_scalar("NPS/GT_Val_Batch_Avg", nps.mean().item(), epoch * len(val_loader) + val_idx)
avg_val_loss = total_val_loss / len(val_loader)
writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)
# Log learning rate
current_lr = optimizer.param_groups[0]["lr"]
writer.add_scalar("LR/learning_rate", current_lr, epoch)
# Log ground truth NPS from the last validation batch (or mean over epoch)
if "nps" in batch: # Check if nps is in the last batch
writer.add_scalar(
"NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
)
print(
f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")
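            # The checkpoint can be restored later with
            # model.load_state_dict(torch.load("best_model.pt")).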
print(f"Saved new best model to best_model.pt at epoch {epoch}")
else:
pat_count += 1
if pat_count >= patience:
print("Early stopping!")
break
writer.close()
model_id = "JacobLinCool/taiko-conformer-5"
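    # push_to_hub presumably comes from huggingface_hub's PyTorchModelHubMixin
    # on TaikoConformer5; any failure is caught and reported below.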
try:
model.push_to_hub(model_id, commit_message="Upload trained model")
upload_folder(
repo_id=model_id,
folder_path="runs",
path_in_repo=".",
commit_message="Upload training logs",
ignore_patterns=["*.txt", "*.json", "*.csv"],
)
print(f"Model and logs uploaded to {model_id}")
except Exception as e:
print(f"Error uploading to Hugging Face Hub: {e}")
print("Make sure you have the correct permissions and try again.")
if __name__ == "__main__":
main()