# Hugging Face Spaces badge text (scraping artifact): "Spaces: Running on Zero"
# Seed every RNG before anything else runs — in particular before the
# project-relative imports below, which may draw random numbers on import.
from accelerate.utils import set_seed

set_seed(1024)

import math

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import concatenate_datasets
from huggingface_hub import upload_folder
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Project-local imports, kept in their original order because they are
# executed after seeding and may have import-time side effects.
from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    NPS_PENALTY_WEIGHT_ALPHA,
    NPS_PENALTY_WEIGHT_BETA,
    SAMPLE_RATE,
)
from .model import TaikoConformer7
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoLoss
def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,
    hop_sec,
):
    """
    Log a figure comparing predicted vs. ground-truth energies for one sample.

    Args:
        writer: ``SummaryWriter`` the figure is added to.
        tag_prefix: TensorBoard tag prefix (figure is logged under
            ``"{tag_prefix}/Energy_Comparison"``).
        epoch: Global step used for the figure.
        pred_don, pred_ka, pred_drumroll: 1D torch tensors of predicted
            energies for a single sample (may live on GPU / require grad —
            they are detached and moved to CPU here).
        true_don, true_ka, true_drumroll: 1D torch tensors of ground-truth
            energies for the same sample.
        valid_length: Number of valid (non-padded) frames to plot.
        hop_sec: Seconds per output frame (x-axis scale).
    """
    valid_length = int(valid_length)  # tolerate tensor / numpy scalar inputs
    # Predictions may carry gradients; detach before numpy conversion.
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()
    time_axis = np.arange(valid_length) * hop_sec
    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)
    # One subplot per channel: (name, truth, prediction, truth color, pred color).
    panels = [
        ("Don", true_don, pred_don, "blue", "lightblue"),
        ("Ka", true_ka, pred_ka, "red", "lightcoral"),
        ("Drumroll", true_drumroll, pred_drumroll, "green", "lightgreen"),
    ]
    for ax, (name, true_vals, pred_vals, true_color, pred_color) in zip(axs, panels):
        ax.plot(
            time_axis, true_vals, label=f"True {name}", color=true_color, linestyle="--"
        )
        ax.plot(time_axis, pred_vals, label=f"Pred {name}", color=pred_color, alpha=0.8)
        ax.set_ylabel(f"{name} Energy")
        ax.legend()
        ax.grid(True)
    axs[2].set_xlabel("Time (s)")
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)  # free the figure; writer keeps its own copy
def main():
    """Train TaikoConformer7 on the oni/hard/normal charts of the dataset.

    Runs gradient-accumulated training with a OneCycleLR schedule, logs
    losses, learning rate and sample energy plots to TensorBoard, keeps the
    best checkpoint (lowest validation loss) in ``best_model.pt`` with early
    stopping, and finally pushes the BEST model plus the TensorBoard logs to
    the Hugging Face Hub.
    """
    global ds  # re-bound below to the concatenated, preprocessed dataset
    # Seconds spanned by one model output frame (x-axis scale for plots).
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    best_val_loss = float("inf")
    patience = 10  # epochs without val-loss improvement before early stop
    pat_count = 0
    # Preprocess each difficulty separately, then pool into one dataset.
    # (The comprehension runs against the original `ds` before rebinding.)
    ds = concatenate_datasets(
        [
            ds.map(
                preprocess,
                remove_columns=ds.column_names,
                fn_kwargs={"difficulty": difficulty},
                writer_batch_size=10,
            )
            for difficulty in ("oni", "hard", "normal")
        ]
    )
    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )
    model = TaikoConformer7().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = TaikoLoss(
        reduction="mean",
        nps_penalty_weight_alpha=NPS_PENALTY_WEIGHT_ALPHA,
        nps_penalty_weight_beta=NPS_PENALTY_WEIGHT_BETA,
    ).to(DEVICE)
    # The scheduler advances once per OPTIMIZER step, not per batch, so size
    # its horizon by the number of accumulation-adjusted steps.
    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )
    writer = SummaryWriter()
    for epoch in range(1, EPOCHS + 1):
        # ----------------------------- training -----------------------------
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()
        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
            lengths = batch["lengths"].to(DEVICE)
            nps = batch["nps"].to(DEVICE)
            difficulty = batch["difficulty"].to(DEVICE)
            level = batch["level"].to(DEVICE)
            outputs = model(mel, lengths, nps, difficulty, level)
            loss = criterion(outputs, batch)
            total_epoch_loss += loss.item()
            # Scale so accumulated gradients average over GRAD_ACCUM_STEPS.
            loss = loss / GRAD_ACCUM_STEPS
            loss.backward()
            # Step on every accumulation boundary and on the final partial batch.
            if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            global_step = epoch * len(train_loader) + idx
            writer.add_scalar(
                "Loss/Train_Step", loss.item() * GRAD_ACCUM_STEPS, global_step
            )
            writer.add_scalar("LR", scheduler.get_last_lr()[0], global_step)
            # Visualize the first sample of the first few batches each epoch.
            if idx < 3 and mel.size(0) > 0:
                log_energy_plots_to_tensorboard(
                    writer,
                    f"Train_Sample_Batch_{idx}_Sample_0",
                    epoch,
                    outputs["presence"][0, :, 0],
                    outputs["presence"][0, :, 1],
                    outputs["presence"][0, :, 2],
                    batch["don_labels"][0],
                    batch["ka_labels"][0],
                    batch["drumroll_labels"][0],
                    batch["lengths"][0].item(),
                    output_frame_hop_sec,
                )
        avg_train_loss = total_epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
        # ---------------------------- validation ----------------------------
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for idx, batch in enumerate(tqdm(val_loader, desc=f"Val Epoch {epoch}")):
                mel = batch["mel"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)
                difficulty = batch["difficulty"].to(DEVICE)
                level = batch["level"].to(DEVICE)
                outputs = model(mel, lengths, nps, difficulty, level)
                loss = criterion(outputs, batch)
                total_val_loss += loss.item()
                if idx < 3 and mel.size(0) > 0:
                    log_energy_plots_to_tensorboard(
                        writer,
                        f"Val_Sample_Batch_{idx}_Sample_0",
                        epoch,
                        outputs["presence"][0, :, 0],
                        outputs["presence"][0, :, 1],
                        outputs["presence"][0, :, 2],
                        batch["don_labels"][0],
                        batch["ka_labels"][0],
                        batch["drumroll_labels"][0],
                        batch["lengths"][0].item(),
                        output_frame_hop_sec,
                    )
        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)
        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)
        # NOTE(review): `batch` here is only the LAST validation batch.
        if "nps" in batch:
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )
        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )
        # ----------------- checkpointing / early stopping -------------------
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break
    writer.close()
    # Fix: upload the BEST checkpoint, not whatever weights the final
    # (possibly early-stopped) epoch left in memory.
    if best_val_loss != float("inf"):
        model.load_state_dict(torch.load("best_model.pt", map_location=DEVICE))
    model_id = "JacobLinCool/taiko-conformer-7"
    try:
        model.push_to_hub(
            model_id, commit_message=f"Epoch {epoch}, Val Loss: {best_val_loss:.4f}"
        )
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo="runs",
            commit_message="Upload TensorBoard logs",
        )
    except Exception as e:
        # Best-effort upload: training results are already saved locally.
        print(f"Error uploading model or logs: {e}")
        print("Make sure you have the correct permissions and try again.")
# Script entry point: start training only when executed directly, not on import.
if __name__ == "__main__":
    main()