File size: 3,719 Bytes
8573586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
'''
Demo training script for Feature Selection Gates (FSG) with ViT on Imagenette

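This script downloads (if needed) the Imagenette dataset (a 10-class subset of
ImageNet, 160 px variant), fine-tunes a pretrained ViT-B/16 augmented with FSG
for a few demo steps on its validation split, and saves the model checkpoint.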

Paper:
https://papers.miccai.org/miccai-2024/316-Paper0410.html
Code:
https://github.com/cosmoimd/feature-selection-gates
Contact:
giorgio.roffo@gmail.com
'''

import os
import tarfile
import urllib.request
import torch
import torch.nn as nn
import torch.optim as optim
import psutil
from tqdm import tqdm
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from vit_with_fsg import vit_with_fsg

# System info
# Prefer the first CUDA device when one is available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n🖥️  Using device: {device}")
if device.type == "cuda":
    # Report name and total memory of CUDA device index 0.
    print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
# Host RAM in GiB, reported for every device type.
print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")

# Dataset path
# NOTE(review): this is the Imagenette *validation* split, and it is the split
# the demo below trains on; point at ".../train" for a real training run.
imagenette_path = "./imagenette2-160/val"
if not os.path.exists(imagenette_path):
    print("📦 Downloading Imagenette...")
    url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
    tgz_path = "imagenette2-160.tgz"
    try:
        urllib.request.urlretrieve(url, tgz_path)
        print("📂 Extracting Imagenette dataset...")
        with tarfile.open(tgz_path, "r:gz") as tar:
            # The "data" filter rejects absolute paths, ".." traversal and links
            # pointing outside the destination. It is available on Python 3.12+
            # and on older releases that received the security backport.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(filter="data")
            else:
                tar.extractall()
    finally:
        # Delete the archive whether or not download/extraction succeeded,
        # so a truncated .tgz is never left lying around.
        if os.path.exists(tgz_path):
            os.remove(tgz_path)
    print("✅ Dataset ready.")

# Transforms
# Every image is resized to a fixed 224x224, converted to a float tensor, and
# normalized per channel with mean = std = 0.5.
# NOTE(review): torchvision's ViT_B_16_Weights presets use ImageNet mean/std
# (with resize-256 + center-crop-224); confirm 0.5 statistics are intended.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Dataset and loader
# ImageFolder derives class labels from the sub-directories of the split root.
dataset = ImageFolder(imagenette_path, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

# Model setup
print("\n📥 Loading pretrained ViT backbone from torchvision...")
# ViT-B/16 initialized with torchvision's DEFAULT pretrained weights.
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
# Wrap the backbone with Feature Selection Gates and move it to `device`.
# NOTE(review): the torchvision backbone ships a 1000-class head while the
# ImageFolder labels here are 0..9; confirm `vit_with_fsg` adapts the head.
model = vit_with_fsg(backbone).to(device)

# Optimizer with separate LRs for FSG and base ViT
# Parameters whose name contains 'fsag_rgb_ls' are treated as FSG gate
# parameters (the name comes from `vit_with_fsg`, not visible here);
# everything else is part of the base ViT.
fsg_params, base_params = [], []
for name, param in model.named_parameters():
    if 'fsag_rgb_ls' in name:
        fsg_params.append(param)
    else:
        base_params.append(param)

if not fsg_params:
    # Surface a silent misconfiguration: with no match, the FSG learning rate
    # below is never applied to anything.
    print("⚠️  No FSG parameters matched 'fsag_rgb_ls'; all parameters use the base LR.")

lr_base = 1e-4
lr_fsg = 5e-4
print("\n🔧 Optimizer setup:")
print(f"   🔹 Base ViT parameters LR: {lr_base}")
print(f"   🔸 FSG parameters LR: {lr_fsg}")

# Two AdamW parameter groups so the gates can learn faster than the backbone.
optimizer = optim.AdamW([
    {"params": base_params, "lr": lr_base},
    {"params": fsg_params, "lr": lr_fsg}
])
criterion = nn.CrossEntropyLoss()

# Training loop
epochs = 3
# Demo-only cap on optimizer steps per epoch; set to None to iterate over the
# whole dataloader every epoch.
demo_max_steps = 25
steps_per_epoch = (
    len(dataloader) if demo_max_steps is None
    else min(len(dataloader), demo_max_steps)
)
print(f"\n🚀 Starting demo training for {epochs} epochs...")
model.train()
for epoch in range(epochs):
    running_loss = 0.0
    # `total` makes the bar reflect the capped step count; the context manager
    # closes the bar even when the loop breaks early.
    with tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100,
              total=steps_per_epoch) as pbar:
        for step, (inputs, targets) in enumerate(pbar, start=1):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Mean loss over the steps taken so far. An explicit counter is
            # used because tqdm only syncs `pbar.n` on display refreshes (and
            # it is still 0 during the first step).
            pbar.set_postfix({"loss": running_loss / step})
            if demo_max_steps is not None and step >= demo_max_steps:
                break

print("\n✅ Training complete.")

# Save checkpoint
ckpt_dir = "./checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_path = os.path.join(ckpt_dir, "fsg_vit_imagenette_demo.pth")
# Only the state_dict (parameters and buffers) is written; loading it requires
# rebuilding the same `vit_with_fsg(...)` model and calling load_state_dict.
torch.save(model.state_dict(), ckpt_path)
print(f"💾 Checkpoint saved to: {ckpt_path}")