File size: 4,910 Bytes
8573586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
'''
Demo script for applying Feature Selection Gates (FSG) to torchvision Vision Transformers
and running inference on the ImageNet-mini (Imagenette) validation set.

Each image is resized to 224x224 and has 3 RGB channels to be compatible with ViT.

Usage:

python demo_inference_imnet.py --checkpoint ./checkpoints/fsg_vit_imagenette_demo.pth

Paper:
https://papers.miccai.org/miccai-2024/316-Paper0410.html
Code:
https://github.com/cosmoimd/feature-selection-gates
Contact:
giorgio.roffo@gmail.com
'''

import warnings
warnings.filterwarnings("ignore")

import os
import sys
import tarfile
import urllib.request
import torch
import psutil
from torchvision.models import vit_b_16, ViT_B_16_Weights
from vit_with_fsg import vit_with_fsg
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm

import argparse

# Command-line interface: a single optional --checkpoint path.
parser = argparse.ArgumentParser(description="FSG-ViT inference on Imagenette")
parser.add_argument(
    "--checkpoint",
    type=str,
    default=None,
    help="Path to .pth file of trained FSG-ViT model",
)
# Only consume sys.argv when executed as a script. On import, fall back to the
# declared defaults so that the importing process's own command line is never
# parsed (which would otherwise abort the import with SystemExit on any
# unrecognised argument).
args = parser.parse_args() if __name__ == "__main__" else parser.parse_args([])

# Source archive for the 160px Imagenette release and the name it is saved under.
IMAGENETTE_URL = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
IMAGENETTE_ARCHIVE = "imagenette2-160.tgz"
IMAGENETTE_VAL_DIR = "./imagenette2-160/val"


def _is_within(root, candidate):
    """Return True when *candidate* resolves to *root* or a path below it."""
    root = os.path.realpath(root)
    candidate = os.path.realpath(candidate)
    # os.path.join(root, "") appends exactly one trailing separator, so a
    # sibling such as "/data-evil" is not mistaken for a child of "/data".
    return candidate == root or candidate.startswith(os.path.join(root, ""))


def extract_tar_safely(archive_path, dest_dir="."):
    """Extract a gzip-compressed tar archive without letting members escape *dest_dir*.

    Args:
        archive_path: Path of the ``.tgz`` / ``.tar.gz`` file to unpack.
        dest_dir: Directory to extract into (defaults to the current directory).

    Raises:
        tarfile.TarError: If a member (or a link target) would be written
            outside *dest_dir*, or the archive cannot be read.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Interpreters with extraction filters: let the stdlib reject
            # absolute paths, "..", and links that point outside dest_dir.
            tar.extractall(dest_dir, filter="data")
            return

        # Older interpreters: validate every member before extracting anything.
        for member in tar.getmembers():
            target = os.path.join(dest_dir, member.name)
            if not _is_within(dest_dir, target):
                raise tarfile.ExtractError(f"unsafe path in archive: {member.name!r}")
            if member.issym() or member.islnk():
                # Symlink targets are relative to the link's directory,
                # hard-link targets are relative to the archive root.
                base = os.path.dirname(target) if member.issym() else dest_dir
                if not _is_within(dest_dir, os.path.join(base, member.linkname)):
                    raise tarfile.ExtractError(f"unsafe link in archive: {member.name!r}")
        tar.extractall(dest_dir)


def ensure_imagenette(val_dir):
    """Download and unpack Imagenette into the current directory if *val_dir* is missing.

    Args:
        val_dir: Path whose existence signals that the dataset is already present.
    """
    if os.path.exists(val_dir):
        return

    print("πŸ“¦ Downloading Imagenette...")
    try:
        urllib.request.urlretrieve(IMAGENETTE_URL, IMAGENETTE_ARCHIVE)
        print("πŸ“‚ Extracting Imagenette dataset...")
        extract_tar_safely(IMAGENETTE_ARCHIVE)
    finally:
        # Remove the archive on success AND on failure, so an interrupted
        # download or a rejected archive is not left behind.
        if os.path.exists(IMAGENETTE_ARCHIVE):
            os.remove(IMAGENETTE_ARCHIVE)
    print("βœ… Dataset ready.")


def collect_predictions(model, dataloader, device):
    """Run *model* over *dataloader* and gather labels and argmax predictions.

    Args:
        model: Callable model; it is invoked on each image batch moved to *device*.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the image batches are moved to before the forward pass.

    Returns:
        A ``(y_true, y_pred)`` tuple of plain Python lists of class indices.
    """
    y_true = []
    y_pred = []
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="πŸ” Inference progress", ncols=100):
            logits = model(images.to(device))
            # softmax is monotonic, so the argmax of the raw outputs selects
            # the same class without the extra normalisation pass.
            y_pred.extend(torch.argmax(logits, dim=1).tolist())
            # Labels are only recorded, never fed to the model, so they do
            # not need a round trip through the device.
            y_true.extend(labels.tolist())
    return y_true, y_pred


def main(checkpoint=None):
    """Evaluate an FSG-wrapped ViT-B/16 on the Imagenette validation split.

    Prints usage and device information, builds the model, optionally loads
    *checkpoint*, fetches the dataset if needed, runs inference and prints
    macro-averaged accuracy / precision / recall / F1.

    Args:
        checkpoint: Optional path to a ``.pth`` state dict for the wrapped
            model. When ``None`` the model is evaluated as constructed.
    """
    # NOTE(review): several message strings below look like mis-decoded emoji;
    # they are deliberately left untouched here - confirm the file's encoding.
    warnings.filterwarnings("ignore", message="Failed to load image Python extension*")
    script_name = os.path.basename(__file__)
    print(f"\nπŸ“Œ To run this script:\n"
          f"   β–Ά Without checkpoint: python {script_name}\n"
          f"   β–Ά With checkpoint:    python {script_name} --checkpoint path/to/model.pth\n")

    # Device and system info
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nπŸ–₯️  Using device: {device}")
    if device.type == "cuda":
        print(f"πŸš€ CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"πŸ’Ύ GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
    print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")

    print("\nπŸ“₯ Loading pretrained ViT backbone from torchvision...")
    backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

    print("πŸ”§ Wrapping with Feature Selection Gates (FSG)...")
    model = vit_with_fsg(backbone).to(device)

    no_checkpoint = checkpoint is None
    if no_checkpoint:
        # NOTE(review): the backbone above is created with pretrained
        # torchvision weights, so "randomly initialized" presumably refers to
        # the FSG layers only - confirm the wording of these messages.
        print("\n⚠️  No checkpoint provided. Evaluating randomly initialized model! πŸ§ͺ\n")
        print("❗ Note: The model has not been trained. Results will reflect a randomly initialized backbone.")
    else:
        print(f"πŸ“‚ Loading model weights from: {checkpoint}")
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source (newer torch releases offer weights_only=True).
        model.load_state_dict(torch.load(checkpoint, map_location=device))

    model.eval()

    print("πŸ“š Loading Imagenette validation set (224x224 RGB)...")
    ensure_imagenette(IMAGENETTE_VAL_DIR)

    # NOTE(review): mean/std of 0.5 differ from the usual ImageNet statistics;
    # presumably this matches how the checkpoint was trained - TODO confirm.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
    ])

    dataset = ImageFolder(root=IMAGENETTE_VAL_DIR, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    print("πŸ§ͺ Running inference on Imagenette validation set using FSG-ViT-B-16 (code by G. Roffo)...\n\n")
    # NOTE(review): predictions are compared directly with ImageFolder's class
    # indices; this assumes the model's output index space matches the
    # dataset's folder ordering - verify against vit_with_fsg / the checkpoint.
    y_true, y_pred = collect_predictions(model, dataloader, device)

    print("βœ… Inference completed.")

    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    if no_checkpoint:
        # Repeat the warning next to the numbers so they are not misread.
        print("\n⚠️  No checkpoint provided. Evaluated randomly initialized model! πŸ§ͺ\n")
        print(f"\nπŸ“Œ To run this script:\n"
              f"   β–Ά With checkpoint:    python {script_name} --checkpoint path/to/model.pth\n")

    print(f"πŸ“Š Accuracy:  {acc * 100:.2f}%")
    print(f"πŸ“Š Precision: {prec * 100:.2f}%")
    print(f"πŸ“Š Recall:    {rec * 100:.2f}%")
    print(f"πŸ“Š F1 Score:  {f1 * 100:.2f}%")


if __name__ == "__main__":
    main(args.checkpoint)