|
''' |
|
Demo training script for Feature Selection Gates (FSG) with ViT on MNIST test set |
|
|
|
This is a minimal demo: we train only on the MNIST test set (resized and converted to 3-channel) |
|
for a few epochs to simulate training, save the checkpoint, and allow downstream inference. |
|
|
|
Paper: |
|
https://papers.miccai.org/miccai-2024/316-Paper0410.html |
|
Code: |
|
https://github.com/cosmoimd/feature-selection-gates |
|
Contact: |
|
giorgio.roffo@gmail.com |
|
''' |
|
|
|
import os |
|
import warnings |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import psutil |
|
from tqdm import tqdm |
|
from torchvision import transforms |
|
from torchvision.datasets import MNIST |
|
from torchvision.models import vit_b_16, ViT_B_16_Weights |
|
from torch.utils.data import DataLoader |
|
from vit_with_fsg import vit_with_fsg |
|
|
|
warnings.filterwarnings("ignore")  # silence library deprecation/user warnings to keep demo console output readable
|
|
|
|
|
# Pick the compute device (CUDA when available, CPU otherwise) and report
# basic hardware info so the demo run is easy to diagnose.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print(f"\nπ₯οΈ Using device: {device}")

if device.type == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    print(f"π CUDA device: {gpu_name}")
    print(f"πΎ GPU memory total: {gpu_total_gb:.2f} GB")

ram_total_gb = psutil.virtual_memory().total / (1024 ** 3)
print(f"π§ System RAM: {ram_total_gb:.2f} GB")
|
|
|
|
|
# Build the demo data pipeline: MNIST digits are upsampled to the ViT input
# resolution and replicated to three channels so the pretrained backbone
# accepts them.
print("\nπ Loading MNIST demo set for demo training (resized to 224x224, 3-channel)...")
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),                 # ViT-B/16 expects 224x224 inputs
    transforms.Grayscale(num_output_channels=3),   # replicate gray channel to RGB
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),          # single stat broadcast over all channels
])

# Demo uses the (smaller) MNIST test split as its "training" data on purpose.
mnist_test = MNIST(root="./data", train=False, download=True, transform=preprocess)
dataloader = DataLoader(mnist_test, batch_size=32, shuffle=True)
|
|
|
|
|
# Load the ImageNet-pretrained ViT-B/16 backbone and wrap it with the
# Feature Selection Gates (FSG) module, then move it to the chosen device.
print("\nπ₯ Loading pretrained ViT backbone from torchvision...")
pretrained_vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
model = vit_with_fsg(pretrained_vit).to(device)
|
|
|
|
|
# Partition the model parameters into two groups so the optimizer can give
# them different learning rates: FSG gate parameters (identified by the
# 'fsag_rgb_ls' naming convention) vs. all remaining backbone weights.
fsg_params = [p for n, p in model.named_parameters() if 'fsag_rgb_ls' in n]
base_params = [p for n, p in model.named_parameters() if 'fsag_rgb_ls' not in n]
|
|
|
|
|
# Two learning rates: the pretrained backbone is fine-tuned gently while the
# freshly initialized FSG gates are allowed to learn faster.
lr_base = 1e-4
lr_fsg = 5e-4
print(f"\nπ§ Optimizer setup:")
print(f" πΉ Base ViT parameters LR: {lr_base}")
print(f" πΈ FSG parameters LR: {lr_fsg}")

param_groups = [
    {"params": base_params, "lr": lr_base},
    {"params": fsg_params, "lr": lr_fsg},
]
optimizer = optim.AdamW(param_groups)
|
|
|
# Demo training: standard cross-entropy over the 10 MNIST classes.
criterion = nn.CrossEntropyLoss()
epochs = 3
print(f"\nπ Starting demo training for {epochs} epochs...")

model.train()
for epoch in range(epochs):
    steps_demo = 0       # batches processed this epoch (also divisor for avg loss)
    running_loss = 0.0   # sum of per-batch losses this epoch
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100)
    for inputs, targets in pbar:
        # Cap each epoch at a handful of batches -- this is only a demo.
        if steps_demo > 25:
            break
        steps_demo += 1
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # BUG FIX: previously divided by (pbar.n + 1e-8). tqdm's `.n` is still 0
        # while the first batch's body runs, so the displayed "average" loss was
        # inflated by ~1e8 on the first step of every epoch. Divide by the
        # explicit step counter instead; it is >= 1 at this point.
        pbar.set_postfix({"loss": running_loss / steps_demo})
|
|
|
# Final status message (reconstructed onto a single line; the string literal
# was split mid-character in the source, which is a syntax error).
print("\n✅ Training complete.")
|
|
|
|
|
# Persist the fine-tuned weights so the downstream inference demo can reload them.
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_file = os.path.join(checkpoint_dir, "fsg_vit_mnist_demo.pth")
torch.save(model.state_dict(), checkpoint_file)
print(f"πΎ Checkpoint saved to: {checkpoint_file}")
|
|