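"""Self-supervised BYOL pretraining script.

Trains the online network and predictor on pairs of augmented views,
tracks cosine similarity, L2 distance, feature and gradient norms,
logs metrics to Weights & Biases, and saves the best online encoder.

Typical invocation (script name assumed):
    python train_byol.py --batch_size 512 --lr 5e-4 --num_epochs 15
"""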
import argparse
from typing import Optional

import torch
import wandb
from torch import nn, optim
from torch.nn.functional import cosine_similarity
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import RandomAugmentedDataset, get_byol_transforms
from src.models import BYOL

def get_data_loaders(
    batch_size: int,
    num_train_samples: int,
    num_val_samples: int,
    shape_params: Optional[dict] = None,
    num_workers: int = 0,
):
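    """Build train and validation DataLoaders over RandomAugmentedDataset
    using the BYOL augmentation pipeline."""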
    augmentations = get_byol_transforms()
    train_dataset = RandomAugmentedDataset(
        augmentations,
        shape_params,
        num_samples=num_train_samples,
        train=True,
    )
    val_dataset = RandomAugmentedDataset(
        augmentations,
        shape_params,
        num_samples=num_val_samples,
        train=False,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, val_loader

def build_model(lr: float):
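    """Create the BYOL model, optimizer, LR scheduler, and target device.

    Only the online network and predictor are optimized; the target
    network is updated separately via model.soft_update_target_network().
    """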
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BYOL().to(device)
    optimizer = optim.Adam(
        list(model.online_network.parameters()) + list(model.online_predictor.parameters()),
        lr=lr,
    )
    # The scheduler is stepped on validation cosine similarity, hence mode="max".
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=2)
    return model, optimizer, scheduler, device

def train_epoch(
    model: nn.Module,
    optimizer: optim.Optimizer,
    train_loader: DataLoader,
    device: torch.device,
) -> dict:
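    """Run one training epoch and return averaged loss and diagnostics
    (cosine similarity, L2 distance, feature norm, gradient norm)."""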
    model.train()
    running_train_loss = 0.0
    total_cos_sim, total_l2_dist, total_feat_norm, total_grad_norm = 0.0, 0.0, 0.0, 0.0
    num_train_batches = 0
    for view_1, view_2 in tqdm(train_loader, desc="Training"):
        view_1 = view_1.to(device)
        view_2 = view_2.to(device)

        loss = model.loss(view_1, view_2)
        optimizer.zero_grad()
        loss.backward()

        # Diagnostics are computed after backward() so gradients are available,
        # but under no_grad() so they do not affect the update.
        with torch.no_grad():
            online_proj1, _ = model(view_1)
            _, target_proj2 = model(view_2)
            cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
            l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
            feat_norm = torch.norm(online_proj1, dim=-1).mean().item()
            grad_norm = torch.norm(
                torch.cat([
                    p.grad.flatten()
                    for p in model.online_network.parameters()
                    if p.grad is not None
                ])
            ).item()
            total_cos_sim += cos_sim
            total_l2_dist += l2_dist
            total_feat_norm += feat_norm
            total_grad_norm += grad_norm

        optimizer.step()
        # Soft-update the target network from the online network.
        model.soft_update_target_network()

        running_train_loss += loss.item()
        num_train_batches += 1

    train_loss = running_train_loss / num_train_batches
    train_cos_sim = total_cos_sim / num_train_batches
    train_l2_dist = total_l2_dist / num_train_batches
    train_feat_norm = total_feat_norm / num_train_batches
    train_grad_norm = total_grad_norm / num_train_batches
    return {
        "loss": train_loss,
        "cos_sim": train_cos_sim,
        "l2_dist": train_l2_dist,
        "feat_norm": train_feat_norm,
        "grad_norm": train_grad_norm,
    }

def validate(
    model: nn.Module,
    val_loader: DataLoader,
    device: torch.device,
) -> dict:
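    """Evaluate the model on the validation loader and return averaged
    loss, cosine similarity, L2 distance, and feature norm."""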
    model.eval()
    running_val_loss = 0.0
    total_cos_sim, total_l2_dist, total_feat_norm = 0.0, 0.0, 0.0
    num_val_batches = 0
    # No gradients are needed during validation.
    with torch.no_grad():
        for view_1, view_2 in tqdm(val_loader, desc="Validation"):
            view_1 = view_1.to(device)
            view_2 = view_2.to(device)

            loss = model.loss(view_1, view_2)
            running_val_loss += loss.item()

            online_proj1, _ = model(view_1)
            _, target_proj2 = model(view_2)
            cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
            l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
            feat_norm = torch.norm(online_proj1, dim=-1).mean().item()
            total_cos_sim += cos_sim
            total_l2_dist += l2_dist
            total_feat_norm += feat_norm
            num_val_batches += 1

    val_loss = running_val_loss / num_val_batches
    val_cos_sim = total_cos_sim / num_val_batches
    val_l2_dist = total_l2_dist / num_val_batches
    val_feat_norm = total_feat_norm / num_val_batches
    return {
        "loss": val_loss,
        "cos_sim": val_cos_sim,
        "l2_dist": val_l2_dist,
        "feat_norm": val_feat_norm,
    }

def train(
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    device: torch.device,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int,
    early_stopping_patience: int = 3,
    save_path: str = "best_byol.pth",
):
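    """Full training loop with wandb logging, checkpointing of the best
    online encoder, LR scheduling on validation cosine similarity, and
    early stopping on stalled validation loss."""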
    best_loss = float("inf")
    epochs_no_improve = 0
    print("Start training...")
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train_metrics = train_epoch(model, optimizer, train_loader, device)
        val_metrics = validate(model, val_loader, device)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_metrics["loss"],
            "train_cos_sim": train_metrics["cos_sim"],
            "train_l2_dist": train_metrics["l2_dist"],
            "train_feat_norm": train_metrics["feat_norm"],
            "train_grad_norm": train_metrics["grad_norm"],
            "val_loss": val_metrics["loss"],
            "val_cos_sim": val_metrics["cos_sim"],
            "val_l2_dist": val_metrics["l2_dist"],
            "val_feat_norm": val_metrics["feat_norm"],
        })
        print(
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"CosSim: {train_metrics['cos_sim']:.4f} | "
            f"L2Dist: {train_metrics['l2_dist']:.4f}"
        )
        print(
            f"Val Loss: {val_metrics['loss']:.4f} | "
            f"CosSim: {val_metrics['cos_sim']:.4f} | "
            f"L2Dist: {val_metrics['l2_dist']:.4f}"
        )

        # Save the online encoder whenever validation loss improves (or cosine
        # similarity is already high); otherwise count towards early stopping.
        current_val_loss = val_metrics["loss"]
        if current_val_loss < best_loss or val_metrics["cos_sim"] >= 0.86:
            best_loss = current_val_loss
            encoder_state_dict = model.online_network.encoder.state_dict()
            torch.save(encoder_state_dict, save_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        scheduler.step(val_metrics["cos_sim"])
        if epochs_no_improve >= early_stopping_patience:
            print(f"Early stopping on epoch {epoch + 1}")
            break

def main(config: dict):
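    """Initialize wandb, build data loaders and model, run training, and
    finish the run."""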
    wandb.init(project="contrastive_learning_byol", config=config)
    train_loader, val_loader = get_data_loaders(
        batch_size=config["batch_size"],
        num_train_samples=config["num_train_samples"],
        num_val_samples=config["num_val_samples"],
        shape_params=config["shape_params"],
    )
    model, optimizer, scheduler, device = build_model(lr=config["lr"])
    train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config["num_epochs"],
        early_stopping_patience=config["early_stopping_patience"],
        save_path=config["save_path"],
    )
    wandb.finish()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train BYOL model")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--num_epochs", type=int, default=15)
    parser.add_argument("--num_train_samples", type=int, default=100000)
    parser.add_argument("--num_val_samples", type=int, default=10000)
    parser.add_argument("--random_intensity", type=int, default=1)
    parser.add_argument("--early_stopping_patience", type=int, default=3)
    parser.add_argument("--save_path", type=str, default="best_byol.pth")
    args = parser.parse_args()

    # --random_intensity is passed as 0/1 and converted to a bool for shape_params.
    config = {
        "batch_size": args.batch_size,
        "lr": args.lr,
        "num_epochs": args.num_epochs,
        "num_train_samples": args.num_train_samples,
        "num_val_samples": args.num_val_samples,
        "shape_params": {
            "random_intensity": bool(args.random_intensity)
        },
        "early_stopping_patience": args.early_stopping_patience,
        "save_path": args.save_path,
    }
    main(config)