"""Train a CrossAttentionClassifier on randomly generated image pairs.

The encoder is initialised from a pretrained checkpoint (``best_byol.pth`` by
default); losses, accuracies, and cross-attention heatmaps are logged to
Weights & Biases.
"""

import argparse
from typing import Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import RandomPairDataset
from src.models import CrossAttentionClassifier, VGGLikeEncode


def visualize_attention(attn_heatmap, epoch: int):
    """Log a single cross-attention map to W&B as a heatmap image."""
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(attn_heatmap, cmap="hot", interpolation="nearest")
    plt.colorbar(im, fraction=0.046, pad=0.04)
    plt.title(f"Attention Heatmap (Flatten 64x64) | Epoch {epoch}")
    wandb.log({"Flatten Attention Heatmap": wandb.Image(fig, caption=f"Flatten 64x64 | Epoch {epoch}")})
    plt.close(fig)


def get_data_loaders(
    num_train_samples: int,
    num_val_samples: int,
    batch_size: int,
    num_workers: int = 0,
    shape_params: Optional[dict] = None,
):
    """Build train and validation loaders over RandomPairDataset."""
    train_dataset = RandomPairDataset(
        shape_params=shape_params, num_samples=num_train_samples, train=True
    )
    val_dataset = RandomPairDataset(
        shape_params=shape_params, num_samples=num_val_samples, train=False
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, val_loader


def build_model(
    path_to_encoder: str,
    lr: float,
    weight_decay: float,
    step_size: int,
    gamma: float,
    device: torch.device,
):
    """Create the classifier around a pretrained encoder, plus loss, optimizer, and scheduler."""
    encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=False)
    # map_location lets a GPU-trained checkpoint load on CPU-only machines.
    encoder.load_state_dict(torch.load(path_to_encoder, map_location=device))
    model = CrossAttentionClassifier(encoder=encoder)
    model = model.to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    return model, criterion, optimizer, scheduler


def train_epoch(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    train_loader: DataLoader,
    device: torch.device,
):
    """Run one training epoch and return the mean loss and accuracy."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for img1, img2, labels in tqdm(train_loader, desc="Training", leave=False):
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        optimizer.zero_grad()
        logits, attn_weights = model(img1, img2)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * img1.size(0)
        # Binary prediction: threshold the sigmoid output at 0.5.
        preds = (torch.sigmoid(logits) > 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


@torch.no_grad()
def validate(
    model: nn.Module,
    criterion: nn.Module,
    val_loader: DataLoader,
    device: torch.device,
):
    """Evaluate on the validation set and return the mean loss and accuracy."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    for img1, img2, labels in tqdm(val_loader, desc="Validation", leave=False):
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        logits, attn_weights = model(img1, img2)
        loss = criterion(logits, labels)

        running_loss += loss.item() * img1.size(0)
        preds = (torch.sigmoid(logits) > 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


def train(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 30,
    save_path: str = "best_attention_classifier.pth",
):
    """Full training loop: checkpoints the best validation loss and logs attention maps each epoch."""
    best_val_loss = float("inf")
    epochs_no_improve = 0  # tracked for potential early stopping; training is not stopped here
    print("Start training...")

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train_loss, train_acc = train_epoch(model, criterion, optimizer, train_loader, device)
        val_loss, val_acc = validate(model, criterion, val_loader, device)
        scheduler.step()

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "lr": optimizer.param_groups[0]["lr"],
        })
        print(
            f"learning rate: {optimizer.param_groups[0]['lr']:.6f}, "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        )

        # Checkpoint whenever the validation loss improves.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Log attention statistics and a heatmap for one validation batch.
        with torch.no_grad():
            sample_img1, sample_img2, sample_labels = next(iter(val_loader))
            sample_img1, sample_img2 = sample_img1.to(device), sample_img2.to(device)
            _, sample_attn_weights = model(sample_img1, sample_img2)
            wandb.log({
                "attention_std": sample_attn_weights.std().item(),
                "attention_mean": sample_attn_weights.mean().item(),
            })
            attn_heatmap = sample_attn_weights[0].detach().cpu().numpy()
            visualize_attention(attn_heatmap, epoch + 1)  # 1-based, to match the logged epoch


def main(config):
    wandb.init(project="cross_attention_classifier", config=config)

    train_loader, val_loader = get_data_loaders(
        shape_params=config["shape_params"],
        num_train_samples=config["num_train_samples"],
        num_val_samples=config["num_val_samples"],
        batch_size=config["batch_size"],
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, criterion, optimizer, scheduler = build_model(
        path_to_encoder=config["path_to_encoder"],
        lr=config["lr"],
        weight_decay=config["weight_decay"],
        step_size=config["step_size"],
        gamma=config["gamma"],
        device=device,
    )

    train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=config["num_epochs"],
        save_path=config["save_path"],
    )

    wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the cross-attention pair classifier")
    parser.add_argument("--path_to_encoder", type=str, default="best_byol.pth")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=8e-5)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--step_size", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--num_train_samples", type=int, default=10000)
    parser.add_argument("--num_val_samples", type=int, default=2000)
    parser.add_argument("--save_path", type=str, default="best_attention_classifier.pth")
    args = parser.parse_args()

    config = {
        "path_to_encoder": args.path_to_encoder,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "weight_decay": args.weight_decay,
        "step_size": args.step_size,
        "gamma": args.gamma,
        "num_epochs": args.num_epochs,
        "num_train_samples": args.num_train_samples,
        "num_val_samples": args.num_val_samples,
        "save_path": args.save_path,
    }
    if "shape_params" not in config:
        # Fall back to the dataset's default shape parameters.
        config["shape_params"] = {}

    main(config)
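
# Usage sketch: the module filename below is an assumption (adjust it to this repo's
# layout); all flags shown are the argparse options defined above.
#
#   python train_classifier.py --path_to_encoder best_byol.pth \
#       --batch_size 256 --lr 8e-5 --num_epochs 30 \
#       --save_path best_attention_classifier.pth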