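"""Training script for a BYOL model: two augmented views per sample, metrics logged
to Weights & Biases, LR scheduling on validation cosine similarity, and early stopping."""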
import argparse
import torch
import wandb
from torch import nn, optim
from torch.nn.functional import cosine_similarity
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Optional
from src.dataset import RandomAugmentedDataset, get_byol_transforms
from src.models import BYOL


def get_data_loaders(
batch_size: int,
num_train_samples: int,
num_val_samples: int,
shape_params: Optional[dict] = None,
num_workers: int = 0
):
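    """Build train/val DataLoaders over datasets that yield two augmented views per sample."""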
augmentations = get_byol_transforms()
train_dataset = RandomAugmentedDataset(
augmentations,
shape_params,
num_samples=num_train_samples,
train=True
)
val_dataset = RandomAugmentedDataset(
augmentations,
shape_params,
num_samples=num_val_samples,
train=False
)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers
)
return train_loader, val_loader


def build_model(lr: float):
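    """Construct the BYOL model, an Adam optimizer over the online network and predictor, and an LR scheduler."""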
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BYOL().to(device)
optimizer = optim.Adam(
list(model.online_network.parameters()) + list(model.online_predictor.parameters()),
lr=lr
)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=2)
return model, optimizer, scheduler, device


def train_epoch(
model: nn.Module,
optimizer: optim.Optimizer,
train_loader: DataLoader,
device: torch.device
) -> dict:
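    """Run one training epoch and return averaged loss and diagnostic metrics."""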
model.train()
running_train_loss = 0.0
total_cos_sim, total_l2_dist, total_feat_norm, total_grad_norm = 0.0, 0.0, 0.0, 0.0
num_train_batches = 0
for (view_1, view_2) in tqdm(train_loader, desc="Training"):
view_1 = view_1.to(device)
view_2 = view_2.to(device)
loss = model.loss(view_1, view_2)
optimizer.zero_grad()
loss.backward()
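        # Diagnostics only: compare online vs. target projections and measure the gradient norm without tracking gradients.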
with torch.no_grad():
            online_proj1, _ = model(view_1)
            _, target_proj2 = model(view_2)
cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
feat_norm = torch.norm(online_proj1, dim=-1).mean().item()
grad_norm = torch.norm(
torch.cat([
p.grad.flatten()
for p in model.online_network.parameters()
if p.grad is not None
])
).item()
total_cos_sim += cos_sim
total_l2_dist += l2_dist
total_feat_norm += feat_norm
total_grad_norm += grad_norm
optimizer.step()
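        # Soft (EMA-style) update of the target network toward the online network.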
model.soft_update_target_network()
running_train_loss += loss.item()
num_train_batches += 1
train_loss = running_train_loss / num_train_batches
train_cos_sim = total_cos_sim / num_train_batches
train_l2_dist = total_l2_dist / num_train_batches
train_feat_norm = total_feat_norm / num_train_batches
train_grad_norm = total_grad_norm / num_train_batches
return {
"loss": train_loss,
"cos_sim": train_cos_sim,
"l2_dist": train_l2_dist,
"feat_norm": train_feat_norm,
"grad_norm": train_grad_norm,
}


@torch.no_grad()
def validate(
model: nn.Module,
val_loader: DataLoader,
device: torch.device
) -> dict:
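    """Evaluate on the validation set and return averaged loss and diagnostic metrics."""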
model.eval()
running_val_loss = 0.0
total_cos_sim, total_l2_dist, total_feat_norm = 0.0, 0.0, 0.0
num_val_batches = 0
for (view_1, view_2) in tqdm(val_loader, desc="Validation"):
view_1 = view_1.to(device)
view_2 = view_2.to(device)
loss = model.loss(view_1, view_2)
running_val_loss += loss.item()
        online_proj1, _ = model(view_1)
        _, target_proj2 = model(view_2)
cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
feat_norm = torch.norm(online_proj1, dim=-1).mean().item()
total_cos_sim += cos_sim
total_l2_dist += l2_dist
total_feat_norm += feat_norm
num_val_batches += 1
val_loss = running_val_loss / num_val_batches
val_cos_sim = total_cos_sim / num_val_batches
val_l2_dist = total_l2_dist / num_val_batches
val_feat_norm = total_feat_norm / num_val_batches
return {
"loss": val_loss,
"cos_sim": val_cos_sim,
"l2_dist": val_l2_dist,
"feat_norm": val_feat_norm
}


def train(
model: nn.Module,
optimizer: optim.Optimizer,
scheduler,
device: torch.device,
train_loader: DataLoader,
val_loader: DataLoader,
num_epochs: int,
early_stopping_patience: int = 3,
save_path: str = "best_byol.pth"
):
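    """Full training loop with W&B logging, checkpointing of the online encoder, and early stopping."""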
best_loss = float("inf")
epochs_no_improve = 0
print("Start training...")
for epoch in range(num_epochs):
print(f"Epoch {epoch + 1}/{num_epochs}")
train_metrics = train_epoch(model, optimizer, train_loader, device)
val_metrics = validate(model, val_loader, device)
wandb.log({
"epoch": epoch + 1,
"train_loss": train_metrics["loss"],
"train_cos_sim": train_metrics["cos_sim"],
"train_l2_dist": train_metrics["l2_dist"],
"train_feat_norm": train_metrics["feat_norm"],
"train_grad_norm": train_metrics["grad_norm"],
"val_loss": val_metrics["loss"],
"val_cos_sim": val_metrics["cos_sim"],
"val_l2_dist": val_metrics["l2_dist"],
"val_feat_norm": val_metrics["feat_norm"],
})
print(
f"Train Loss: {train_metrics['loss']:.4f} | "
f"CosSim: {train_metrics['cos_sim']:.4f} | "
f"L2Dist: {train_metrics['l2_dist']:.4f}"
)
print(
f"Val Loss: {val_metrics['loss']:.4f} | "
f"CosSim: {val_metrics['cos_sim']:.4f} | "
f"L2Dist: {val_metrics['l2_dist']:.4f}"
)
current_val_loss = val_metrics["loss"]
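        # Checkpoint whenever validation loss improves, or once cosine similarity is already high.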
if current_val_loss < best_loss or val_metrics['cos_sim'] >= 0.86:
best_loss = current_val_loss
encoder_state_dict = model.online_network.encoder.state_dict()
torch.save(encoder_state_dict, save_path)
epochs_no_improve = 0
else:
epochs_no_improve += 1
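        # ReduceLROnPlateau was built with mode='max', so it monitors validation cosine similarity.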
scheduler.step(val_metrics["cos_sim"])
if epochs_no_improve >= early_stopping_patience:
print(f"Early stopping on epoch {epoch + 1}")
break


def main(config: dict):
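    """Run a full BYOL training experiment and log metrics to Weights & Biases."""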
wandb.init(project="contrastive_learning_byol", config=config)
train_loader, val_loader = get_data_loaders(
batch_size=config["batch_size"],
num_train_samples=config["num_train_samples"],
num_val_samples=config["num_val_samples"],
shape_params=config["shape_params"]
)
model, optimizer, scheduler, device = build_model(
lr=config["lr"]
)
train(
model=model,
optimizer=optimizer,
scheduler=scheduler,
device=device,
train_loader=train_loader,
val_loader=val_loader,
num_epochs=config["num_epochs"],
early_stopping_patience=config["early_stopping_patience"],
save_path=config["save_path"]
)
wandb.finish()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train BYOL model")
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--lr", type=float, default=5e-4)
parser.add_argument("--num_epochs", type=int, default=15)
parser.add_argument("--num_train_samples", type=int, default=100000)
parser.add_argument("--num_val_samples", type=int, default=10000)
parser.add_argument("--random_intensity", type=int, default=1)
parser.add_argument("--early_stopping_patience", type=int, default=3)
parser.add_argument("--save_path", type=str, default="best_byol.pth")
args = parser.parse_args()
config = {
"batch_size": args.batch_size,
"lr": args.lr,
"num_epochs": args.num_epochs,
"num_train_samples": args.num_train_samples,
"num_val_samples": args.num_val_samples,
"shape_params": {
"random_intensity": bool(args.random_intensity)
},
"early_stopping_patience": args.early_stopping_patience,
"save_path": args.save_path
}
main(config)
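
# Example invocation from the repository root (so that src.dataset and src.models are importable);
# the flag values shown match the argparse defaults above:
#   python train_byol.py --batch_size 512 --lr 5e-4 --num_epochs 15 --random_intensity 1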