groffo committed
Commit 8573586 · 0 Parent(s)

Initial commit of FSG-ViT
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
.idea/ViT_with_FSG.iml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="jdk" jdkName="Python 3.10 (cvpr)" jdkType="Python SDK" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+   <component name="PyDocumentationSettings">
+     <option name="format" value="PLAIN" />
+     <option name="myDocStringFormat" value="Plain" />
+   </component>
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="PyPackageRequirementsInspection" enabled="false" level="WARNING" enabled_by_default="false" />
+   </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (cvpr)" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/ViT_with_FSG.iml" filepath="$PROJECT_DIR$/.idea/ViT_with_FSG.iml" />
+     </modules>
+   </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="$PROJECT_DIR$" vcs="Git" />
+   </component>
+ </project>
README.md ADDED
@@ -0,0 +1,144 @@
+ # 🔬 Feature Selection Gates (FSG) for Vision Transformers (ViT)
+ 
+ This repository provides a modular, extensible PyTorch implementation of **Feature Selection Gates (FSG)** with **Gradient Routing (GR)**, integrated into **Vision Transformers (ViTs)**. The approach is proposed in:
+ 
+ > **Feature Selection Gates with Gradient Routing for Endoscopic Image Computing**
+ > Giorgio Roffo, Carlo Biffi, Pietro Salvagnini, Andrea Cherubini
+ > Presented at MICCAI 2024
+ > 📄 [Paper](https://papers.miccai.org/miccai-2024/316-Paper0410.html) | 🧠 [arXiv](https://arxiv.org/abs/2407.04400) | 💻 [Code](https://github.com/cosmoimd/feature-selection-gates)
+ 
+ ---
+ 
+ ## 📌 What Is FSG?
+ 
+ **FSG** introduces **learnable gates** that sparsify transformer blocks by modulating their residual connections, acting as **online feature selectors**. This encourages **sparse connectivity**, which reduces overfitting and improves generalization, especially on small and imbalanced datasets.
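+ 
+ As a minimal illustration of the idea (see `vit_with_fsg.py` in this repository for the actual implementation), a gated residual path can be sketched as:
+ 
+ ```python
+ import torch
+ import torch.nn as nn
+ 
+ class GatedResidual(nn.Module):
+     """Illustrative sketch only: a per-channel sigmoid gate on a residual branch."""
+     def __init__(self, dim, sublayer):
+         super().__init__()
+         self.sublayer = sublayer                    # e.g. an attention or MLP block
+         self.gate = nn.Parameter(torch.zeros(dim))  # one learnable gate per channel
+ 
+     def forward(self, x):
+         # The sigmoid keeps gate values in (0, 1); small values suppress the branch
+         return x + torch.sigmoid(self.gate) * self.sublayer(x)
+ ```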
+ 
+ **Gradient Routing (GR)** enables dual-phase optimization:
+ - One optimizer updates FSG parameters
+ - A second optimizer updates the base model
+ 
+ This separation allows **task-specific tuning** and ensures stable learning.
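+ 
+ In this repository's demo scripts, GR is approximated with a single `AdamW` optimizer and two parameter groups, giving the gates a higher learning rate. A rough sketch (assuming `model` is a ViT already wrapped by `vit_with_fsg`):
+ 
+ ```python
+ import torch.optim as optim
+ 
+ # Split parameters by name: FSG gates vs. everything else in the wrapped ViT
+ fsg_params = [p for n, p in model.named_parameters() if "fsg_rgb_ls" in n]
+ base_params = [p for n, p in model.named_parameters() if "fsg_rgb_ls" not in n]
+ 
+ optimizer = optim.AdamW([
+     {"params": base_params, "lr": 1e-4},  # base ViT weights
+     {"params": fsg_params, "lr": 5e-4},   # FSG gates are tuned more aggressively
+ ])
+ ```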
+ 
+ ---
+ 
+ ## 💡 Why Use FSG?
+ 
+ ✅ **Plug & play**: Can be integrated into **any ViT architecture**
+ ✅ Works on **natural images**, **medical images**, and beyond
+ ✅ Can be adapted to **NLP Transformers** like GPTs and BERT
+ ✅ Lightweight and highly regularizing
+ ✅ Compatible with **multi-stream CNNs** and hybrid models
+ 
+ ⚠️ While our focus is on **endoscopic image computing**, the method has also shown performance improvements on **CIFAR-100**, demonstrating its applicability to **standard vision tasks**.
+ 
+ ---
+ 
+ ## 🧪 How to Use the FSG Wrapper
+ 
+ Use the `vit_with_fsg` wrapper from `vit_with_fsg.py` to augment a pretrained ViT from `torchvision`:
+ 
+ ```python
+ from torchvision.models import vit_b_16, ViT_B_16_Weights
+ from vit_with_fsg import vit_with_fsg
+ import torch
+ 
+ print("📥 Loading pretrained ViT_B_16...")
+ backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+ 
+ print("🔧 Wrapping with Feature Selection Gates (FSG)...")
+ model = vit_with_fsg(vit_backbone=backbone)
+ 
+ print("🧪 Running dummy input...")
+ dummy_input = torch.randn(1, 3, 224, 224)
+ output = model(dummy_input)
+ 
+ print("✅ Done. Output shape:", output.shape)
+ ```
+ 
+ ---
+ 
+ ## 🚀 Demo Scripts
+ 
+ We provide full working training and inference examples:
+ 
+ | Dataset     | Training Script            | Inference Script            | Checkpoint Path                              |
+ |-------------|----------------------------|-----------------------------|----------------------------------------------|
+ | MNIST       | `demo_training_mnist.py`   | `demo_inference_mnist.py`   | `./checkpoints/fsg_vit_mnist_demo.pth`       |
+ | Imagenette  | `demo_training_imnet.py`   | `demo_inference_imnet.py`   | `./checkpoints/fsg_vit_imagenette_demo.pth`  |
+ 
+ Each demo:
+ - Trains a ViT-B/16 with FSG on a reduced dataset for speed.
+ - Uses separate learning rates for FSG and base model parameters.
+ - Includes GPU-aware prints and a training progress bar.
+ - Saves checkpoints for reproducible inference.
+ 
+ ### ▶️ Example Usage
+ 
+ ```bash
+ # Train on Imagenette
+ python demo_training_imnet.py
+ 
+ # Inference on Imagenette
+ python demo_inference_imnet.py --checkpoint ./checkpoints/fsg_vit_imagenette_demo.pth
+ ```
+ 
+ ```bash
+ # Train on MNIST
+ python demo_training_mnist.py
+ 
+ # Inference on MNIST
+ python demo_inference_mnist.py --checkpoint ./checkpoints/fsg_vit_mnist_demo.pth
+ ```
+ 
+ > ⚠️ These demos use reduced test sets and train for only a few iterations to keep runtime short. They are not meant for benchmarking, but for showcasing FSG integration.
+ 
+ ---
+ 
+ ## 🧠 Applicability Beyond Endoscopy
+ 
+ Although designed for **polyp size estimation in colonoscopy**, FSG is a **general mechanism** for:
+ - **Image classification**
+ - **Medical image analysis**
+ - **Multimodal fusion**
+ - **NLP Transformers** (e.g., GPTs, BERT): apply FSG over token embeddings (see the sketch below)
+ 
+ We strongly encourage researchers to test FSG in **non-medical** domains.
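+ 
+ A rough, untested sketch of the NLP adaptation mentioned above (class and variable names are purely illustrative, not part of this repository):
+ 
+ ```python
+ import torch
+ import torch.nn as nn
+ 
+ class GatedEmbedding(nn.Module):
+     """Hypothetical adaptation: gate each embedding dimension before the encoder stack."""
+     def __init__(self, embedding: nn.Embedding):
+         super().__init__()
+         self.embedding = embedding
+         self.gate = nn.Parameter(torch.zeros(embedding.embedding_dim))  # one gate per dimension
+ 
+     def forward(self, token_ids):
+         return torch.sigmoid(self.gate) * self.embedding(token_ids)
+ ```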
+ 
+ ---
+ 
+ ## 📦 Files and Structure
+ 
+ ```
+ .
+ ├── vit_with_fsg.py           # ViT + FSG wrapper
+ ├── demo_training_mnist.py
+ ├── demo_inference_mnist.py
+ ├── demo_training_imnet.py
+ ├── demo_inference_imnet.py
+ └── checkpoints/              # Folder for .pth checkpoints
+ ```
+ 
+ ---
+ 
+ ## 📚 Citation
+ 
+ Please cite our work if you use this repository:
+ 
+ ```bibtex
+ @inproceedings{roffo2024FSG,
+   title={Feature Selection Gates with Gradient Routing for Endoscopic Image Computing},
+   author={Giorgio Roffo and Carlo Biffi and Pietro Salvagnini and Andrea Cherubini},
+   booktitle={MICCAI 2024, the 27th International Conference on Medical Image Computing and Computer Assisted Intervention, Marrakech, Morocco, October 2024},
+   year={2024},
+   organization={Springer}
+ }
+ ```
+ 
+ ---
+ 
+ ## 📬 Contact
+ 
+ Lead Author: **Giorgio Roffo**
+ 📧 giorgio.roffo@gmail.com
+ 🏢 Cosmo Intelligent Medical Devices (IMD), Lainate, Italy
+ 
+ For more: [github.com/cosmoimd/feature-selection-gates](https://github.com/cosmoimd/feature-selection-gates)
demo_inference_imnet.py ADDED
@@ -0,0 +1,124 @@
+ '''
+ Demo script for applying Feature Selection Gates (FSG) to torchvision Vision Transformers
+ and running inference on the ImageNet-mini (Imagenette) validation set.
+ 
+ Each image is resized to 224x224 and has 3 RGB channels to be compatible with ViT.
+ 
+ Usage:
+ 
+     python demo_inference_imnet.py --checkpoint ./checkpoints/fsg_vit_imagenette_demo.pth
+ 
+ Paper:
+     https://papers.miccai.org/miccai-2024/316-Paper0410.html
+ Code:
+     https://github.com/cosmoimd/feature-selection-gates
+ Contact:
+     giorgio.roffo@gmail.com
+ '''
+ 
+ import warnings
+ warnings.filterwarnings("ignore")
+ 
+ import os
+ import sys
+ import tarfile
+ import urllib.request
+ import torch
+ import psutil
+ from torchvision.models import vit_b_16, ViT_B_16_Weights
+ from vit_with_fsg import vit_with_fsg
+ from torchvision import transforms
+ from torchvision.datasets import ImageFolder
+ from torch.utils.data import DataLoader
+ import torch.nn.functional as F
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
+ from tqdm import tqdm
+ 
+ import argparse
+ 
+ parser = argparse.ArgumentParser(description="FSG-ViT inference on Imagenette")
+ parser.add_argument("--checkpoint", type=str, default=None, help="Path to .pth file of trained FSG-ViT model")
+ args = parser.parse_args()
+ 
+ if __name__ == "__main__":
+     warnings.filterwarnings("ignore", message="Failed to load image Python extension*")
+     wrn = False
+     print(f"\n📌 To run this script:\n"
+           f" ▶ Without checkpoint: python {os.path.basename(__file__)}\n"
+           f" ▶ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")
+ 
+     # Device and system info
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"\n🖥️ Using device: {device}")
+     if device.type == "cuda":
+         print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
+         print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
+     print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")
+ 
+     print("\n📥 Loading pretrained ViT backbone from torchvision...")
+     backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+ 
+     print("🔧 Wrapping with Feature Selection Gates (FSG)...")
+     model = vit_with_fsg(backbone).to(device)
+ 
+     if args.checkpoint is not None:
+         print(f"📂 Loading model weights from: {args.checkpoint}")
+         model.load_state_dict(torch.load(args.checkpoint, map_location=device))
+     else:
+         wrn = True
+         print("\n⚠️ No checkpoint provided. Evaluating randomly initialized model! 🧪\n")
+         print("❗ Note: The model has not been trained. Results will reflect a randomly initialized backbone.")
+ 
+     model.eval()
+ 
+     print("📚 Loading Imagenette validation set (224x224 RGB)...")
+     imagenette_path = "./imagenette2-160/val"
+     if not os.path.exists(imagenette_path):
+         print("📦 Downloading Imagenette...")
+         url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
+         tgz_path = "imagenette2-160.tgz"
+         urllib.request.urlretrieve(url, tgz_path)
+         print("📂 Extracting Imagenette dataset...")
+         with tarfile.open(tgz_path, "r:gz") as tar:
+             tar.extractall()
+         os.remove(tgz_path)
+         print("✅ Dataset ready.")
+ 
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
+     ])
+ 
+     dataset = ImageFolder(root=imagenette_path, transform=transform)
+     dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
+ 
+     y_true = []
+     y_pred = []
+ 
+     print("🧪 Running inference on Imagenette validation set using FSG-ViT-B-16 (code by G. Roffo)...\n\n")
+     with torch.no_grad():
+         for images, labels in tqdm(dataloader, desc="🔍 Inference progress", ncols=100):
+             images = images.to(device)
+             labels = labels.to(device)
+             outputs = model(images)
+             preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)
+             y_true.extend(labels.cpu().tolist())
+             y_pred.extend(preds.cpu().tolist())
+ 
+     print("✅ Inference completed.")
+ 
+     acc = accuracy_score(y_true, y_pred)
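+     # Precision/recall/F1 below use macro averaging: each class contributes equally, regardless of frequency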
+     prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
+     rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
+     f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
+ 
+     if wrn:
+         print("\n⚠️ No checkpoint provided. Evaluated randomly initialized model! 🧪\n")
+         print(f"\n📌 To run this script:\n"
+               f" ▶ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")
+ 
+     print(f"📊 Accuracy: {acc * 100:.2f}%")
+     print(f"📊 Precision: {prec * 100:.2f}%")
+     print(f"📊 Recall: {rec * 100:.2f}%")
+     print(f"📊 F1 Score: {f1 * 100:.2f}%")
demo_inference_mnist.py ADDED
@@ -0,0 +1,108 @@
+ '''
+ Demo script for applying Feature Selection Gates (FSG) to torchvision Vision Transformers
+ and running inference on the MNIST test set.
+ 
+ Each MNIST image is resized to 224x224 and converted to 3 channels to be compatible with ViT.
+ 
+ Usage:
+ 
+     python demo_inference_mnist.py --checkpoint ./checkpoints/fsg_vit_mnist_demo.pth
+ 
+ Paper:
+     https://papers.miccai.org/miccai-2024/316-Paper0410.html
+ Code:
+     https://github.com/cosmoimd/feature-selection-gates
+ Contact:
+     giorgio.roffo@gmail.com
+ '''
+ 
+ import torch
+ import psutil
+ import argparse
+ import warnings
+ from torchvision.models import vit_b_16, ViT_B_16_Weights
+ from vit_with_fsg import vit_with_fsg
+ from torchvision.datasets import MNIST
+ from torchvision import transforms
+ from torch.utils.data import DataLoader
+ import torch.nn.functional as F
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
+ from tqdm import tqdm
+ import os
+ 
+ warnings.filterwarnings("ignore")
+ 
+ parser = argparse.ArgumentParser(description="FSG-ViT inference on MNIST")
+ parser.add_argument("--checkpoint", type=str, default=None, help="Path to .pth file of trained FSG-ViT model")
+ args = parser.parse_args()
+ 
+ if __name__ == "__main__":
+     warnings.filterwarnings("ignore", message="Failed to load image Python extension*")
+     wrn = False
+     print(f"\n📌 To run this script:\n"
+           f" ▶ Without checkpoint: python {os.path.basename(__file__)}\n"
+           f" ▶ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")
+ 
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"\n🖥️ Using device: {device}")
+     if device.type == "cuda":
+         print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
+         print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
+     print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")
+ 
+     print("\n📥 Loading pretrained ViT backbone from torchvision...")
+     backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+ 
+     print("🔧 Wrapping with Feature Selection Gates (FSG)...")
+     model = vit_with_fsg(backbone).to(device)
+ 
+     if args.checkpoint is not None:
+         print(f"📂 Loading model weights from: {args.checkpoint}")
+         model.load_state_dict(torch.load(args.checkpoint, map_location=device))
+     else:
+         wrn = True
+         print("\n⚠️ No checkpoint provided. Evaluating randomly initialized model! 🧪\n")
+         print("❗ Note: The model has not been trained. Results will reflect a randomly initialized backbone.")
+ 
+     model.eval()
+ 
+     print("📚 Loading MNIST test set (resized to 224x224, 3-channel)...")
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.Grayscale(num_output_channels=3),
+         transforms.ToTensor(),
+         transforms.Normalize((0.5,), (0.5,))
+     ])
+ 
+     test_dataset = MNIST(root="./data", train=False, download=True, transform=transform)
+     test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
+ 
+     y_true = []
+     y_pred = []
+ 
+     print("🧪 Running inference on MNIST test set using FSG-ViT-B-16 (code by G. Roffo)...")
+     with torch.no_grad():
+         for images, labels in tqdm(test_loader, desc="🔍 Inference progress", ncols=100):
+             images = images.to(device)
+             labels = labels.to(device)
+             outputs = model(images)
+             preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)
+             y_true.extend(labels.cpu().tolist())
+             y_pred.extend(preds.cpu().tolist())
+ 
+     print("✅ Inference completed.")
+ 
+     acc = accuracy_score(y_true, y_pred)
+     prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
+     rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
+     f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
+ 
+     if wrn:
+         print("\n⚠️ No checkpoint provided. Evaluated randomly initialized model! 🧪\n")
+         print(f"\n📌 To run this script:\n"
+               f" ▶ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")
+ 
+     print(f"📊 Accuracy: {acc * 100:.2f}%")
+     print(f"📊 Precision: {prec * 100:.2f}%")
+     print(f"📊 Recall: {rec * 100:.2f}%")
+     print(f"📊 F1 Score: {f1 * 100:.2f}%")
demo_training_imnet.py ADDED
@@ -0,0 +1,114 @@
+ '''
+ Demo training script for Feature Selection Gates (FSG) with ViT on Imagenette
+ 
+ This script loads the Imagenette dataset (ImageNet-mini),
+ trains a ViT model augmented with FSG, and saves the model checkpoint.
+ 
+ Paper:
+     https://papers.miccai.org/miccai-2024/316-Paper0410.html
+ Code:
+     https://github.com/cosmoimd/feature-selection-gates
+ Contact:
+     giorgio.roffo@gmail.com
+ '''
+ 
+ import os
+ import tarfile
+ import urllib.request
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import psutil
+ from tqdm import tqdm
+ from torchvision import transforms
+ from torchvision.models import vit_b_16, ViT_B_16_Weights
+ from torchvision.datasets import ImageFolder
+ from torch.utils.data import DataLoader
+ from vit_with_fsg import vit_with_fsg
+ 
+ # System info
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"\n🖥️ Using device: {device}")
+ if device.type == "cuda":
+     print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
+     print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
+ print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")
+ 
+ # Dataset path (this demo trains on the small Imagenette val split to keep runtime short)
+ imagenette_path = "./imagenette2-160/val"
+ if not os.path.exists(imagenette_path):
+     print("📦 Downloading Imagenette...")
+     url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
+     tgz_path = "imagenette2-160.tgz"
+     urllib.request.urlretrieve(url, tgz_path)
+     print("📂 Extracting Imagenette dataset...")
+     with tarfile.open(tgz_path, "r:gz") as tar:
+         tar.extractall()
+     os.remove(tgz_path)
+     print("✅ Dataset ready.")
+ 
+ # Transforms
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
+ ])
+ 
+ # Dataset and loader
+ dataset = ImageFolder(root=imagenette_path, transform=transform)
+ dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+ 
+ # Model setup
+ print("\n📥 Loading pretrained ViT backbone from torchvision...")
+ backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+ model = vit_with_fsg(backbone).to(device)
+ 
+ # Optimizer with separate LRs for FSG and base ViT
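+ # Note: the paper's Gradient Routing uses two optimizers with dedicated update phases;
+ # this demo approximates that with a single AdamW and two parameter groups at different LRs.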
+ fsg_params, base_params = [], []
+ for name, param in model.named_parameters():
+     if 'fsg_rgb_ls' in name:  # matches the gate parameters defined in vit_with_fsg.py
+         fsg_params.append(param)
+     else:
+         base_params.append(param)
+ 
+ lr_base = 1e-4
+ lr_fsg = 5e-4
+ print(f"\n🔧 Optimizer setup:")
+ print(f" 🔹 Base ViT parameters LR: {lr_base}")
+ print(f" 🔸 FSG parameters LR: {lr_fsg}")
+ 
+ optimizer = optim.AdamW([
+     {"params": base_params, "lr": lr_base},
+     {"params": fsg_params, "lr": lr_fsg}
+ ])
+ criterion = nn.CrossEntropyLoss()
+ 
+ # Training loop
+ epochs = 3
+ print(f"\n🚀 Starting demo training for {epochs} epochs...")
+ model.train()
+ for epoch in range(epochs):
+     steps_demo = 0  # to remove: for demo only
+     running_loss = 0.0
+     pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100)
+     for inputs, targets in pbar:
+         if steps_demo > 25:  # to remove: for demo only
+             break  # to remove: for demo only
+         steps_demo += 1  # to remove: for demo only
+         inputs, targets = inputs.to(device), targets.to(device)
+         optimizer.zero_grad()
+         outputs = model(inputs)
+         loss = criterion(outputs, targets)
+         loss.backward()
+         optimizer.step()
+         running_loss += loss.item()
+         pbar.set_postfix({"loss": running_loss / (pbar.n + 1)})  # average loss over steps so far
+ 
+ print("\n✅ Training complete.")
+ 
+ # Save checkpoint
+ ckpt_dir = "./checkpoints"
+ os.makedirs(ckpt_dir, exist_ok=True)
+ ckpt_path = os.path.join(ckpt_dir, "fsg_vit_imagenette_demo.pth")
+ torch.save(model.state_dict(), ckpt_path)
+ print(f"💾 Checkpoint saved to: {ckpt_path}")
demo_training_mnist.py ADDED
@@ -0,0 +1,106 @@
+ '''
+ Demo training script for Feature Selection Gates (FSG) with ViT on the MNIST test set
+ 
+ This is a minimal demo: we train only on the MNIST test set (resized and converted to 3-channel)
+ for a few epochs to simulate training, save the checkpoint, and allow downstream inference.
+ 
+ Paper:
+     https://papers.miccai.org/miccai-2024/316-Paper0410.html
+ Code:
+     https://github.com/cosmoimd/feature-selection-gates
+ Contact:
+     giorgio.roffo@gmail.com
+ '''
+ 
+ import os
+ import warnings
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import psutil
+ from tqdm import tqdm
+ from torchvision import transforms
+ from torchvision.datasets import MNIST
+ from torchvision.models import vit_b_16, ViT_B_16_Weights
+ from torch.utils.data import DataLoader
+ from vit_with_fsg import vit_with_fsg
+ 
+ warnings.filterwarnings("ignore")
+ 
+ # Device info
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"\n🖥️ Using device: {device}")
+ if device.type == "cuda":
+     print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
+     print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
+ print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")
+ 
+ # Dataset loading
+ print("\n📚 Loading MNIST test set for demo training (resized to 224x224, 3-channel)...")
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.Grayscale(num_output_channels=3),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5,), (0.5,))
+ ])
+ 
+ dataset = MNIST(root="./data", train=False, download=True, transform=transform)
+ dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+ 
+ # Load ViT backbone and wrap with FSG
+ print("\n📥 Loading pretrained ViT backbone from torchvision...")
+ backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+ model = vit_with_fsg(backbone).to(device)
+ 
+ # Prepare optimizer with different LRs for FSG parameters and base model
+ fsg_params = []
+ base_params = []
+ for name, param in model.named_parameters():
+     if 'fsg_rgb_ls' in name:  # matches the gate parameters defined in vit_with_fsg.py
+         fsg_params.append(param)
+     else:
+         base_params.append(param)
+ 
+ # Assign a higher LR to FSG parameters, lower to base ViT params
+ lr_base = 1e-4
+ lr_fsg = 5e-4
+ print(f"\n🔧 Optimizer setup:")
+ print(f" 🔹 Base ViT parameters LR: {lr_base}")
+ print(f" 🔸 FSG parameters LR: {lr_fsg}")
+ 
+ optimizer = optim.AdamW([
+     {"params": base_params, "lr": lr_base},
+     {"params": fsg_params, "lr": lr_fsg}
+ ])
+ 
+ criterion = nn.CrossEntropyLoss()
+ epochs = 3
+ print(f"\n🚀 Starting demo training for {epochs} epochs...")
+ 
+ model.train()
+ for epoch in range(epochs):
+     steps_demo = 0  # to remove: for demo only
+     running_loss = 0.0
+     pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100)
+     for inputs, targets in pbar:
+         if steps_demo > 25:  # to remove: for demo only
+             break  # to remove: for demo only
+         steps_demo += 1  # to remove: for demo only
+         inputs, targets = inputs.to(device), targets.to(device)
+         optimizer.zero_grad()
+         outputs = model(inputs)
+         loss = criterion(outputs, targets)
+         loss.backward()
+         optimizer.step()
+ 
+         running_loss += loss.item()
+         pbar.set_postfix({"loss": running_loss / (pbar.n + 1)})  # average loss over steps so far
+ 
+ print("\n✅ Training complete.")
+ 
+ # Save checkpoint
+ ckpt_dir = "./checkpoints"
+ os.makedirs(ckpt_dir, exist_ok=True)
+ ckpt_path = os.path.join(ckpt_dir, "fsg_vit_mnist_demo.pth")
+ torch.save(model.state_dict(), ckpt_path)
+ print(f"💾 Checkpoint saved to: {ckpt_path}")
vit_with_fsg.py ADDED
@@ -0,0 +1,109 @@
+ '''
+ ViTwithFSG: Vision Transformer wrapper with Feature Selection Gates (FSG)
+ 
+ This script defines a wrapper class to apply Feature Selection Gates (FSG) to a Vision Transformer (ViT) model.
+ FSG enhances model generalization by introducing sparse, learnable gates on the residual paths of attention and MLP blocks.
+ It is a form of architectural regularization designed for vision tasks and applicable to NLP tasks.
+ 
+ The method is introduced in:
+ 
+ @inproceedings{roffo2024FSG,
+   title={Feature Selection Gates with Gradient Routing for Endoscopic Image Computing},
+   author={Giorgio Roffo and Carlo Biffi and Pietro Salvagnini and Andrea Cherubini},
+   booktitle={MICCAI 2024, the 27th International Conference on Medical Image Computing and Computer Assisted Intervention, Marrakech, Morocco, October 2024},
+   year={2024},
+   organization={Springer}
+ }
+ 
+ - Publication: https://papers.miccai.org/miccai-2024/316-Paper0410.html
+ - Code: https://github.com/cosmoimd/feature-selection-gates
+ - Contact: giorgio.roffo@gmail.com
+ - Affiliation: Cosmo Intelligent Medical Devices (IMD), Lainate, Italy
+ '''
+ 
+ # imports
+ import warnings
+ warnings.filterwarnings("ignore")
+ 
+ import torch
+ import torch.nn as nn
+ from torchvision.models.vision_transformer import VisionTransformer
+ 
+ class FSGBlock(nn.Module):
+     """
+     A Transformer encoder block augmented with Feature Selection Gates (FSG).
+     Each residual path (attention and MLP) is weighted element-wise by a learnable sigmoid gate.
+     This promotes sparse activation and serves as a regularization mechanism to avoid overfitting.
+     """
+     def __init__(self, original_block):
+         super().__init__()
+         self.self_attention = original_block.self_attention  # Multi-head self-attention module
+         self.mlp = original_block.mlp  # Feedforward network (2-layer MLP)
+         self.ln_1 = original_block.ln_1  # LayerNorm before attention
+         self.ln_2 = original_block.ln_2  # LayerNorm before MLP
+         self.dropout = original_block.dropout  # Dropout after attention
+ 
+         dim = self.ln_1.normalized_shape[0]  # Dimensionality of the model
+ 
+         # FSG: learnable gates (one per channel), initialized with Xavier normal
+         self.fsg_rectifier = nn.Sigmoid()
+         self.fsg_rgb_ls1 = nn.Parameter(torch.empty(dim))  # Gate for attention path
+         self.fsg_rgb_ls2 = nn.Parameter(torch.empty(dim))  # Gate for MLP path
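+         # xavier_normal_ requires at least a 2-D tensor, so the 1-D gates are initialized
+         # through a temporary (1, dim) view; the view shares storage with the parameters.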
+         nn.init.xavier_normal_(self.fsg_rgb_ls1.unsqueeze(0), gain=nn.init.calculate_gain('sigmoid'))
+         nn.init.xavier_normal_(self.fsg_rgb_ls2.unsqueeze(0), gain=nn.init.calculate_gain('sigmoid'))
+ 
+     def forward(self, x):
+         # Self-attention + gate
+         x_norm = self.ln_1(x)
+         attn_output, _ = self.self_attention(x_norm, x_norm, x_norm, need_weights=False)
+         attn_output = self.dropout(attn_output)
+         fsg_scores_1 = self.fsg_rectifier(self.fsg_rgb_ls1)
+         x = x + attn_output * fsg_scores_1  # Residual connection weighted by gate
+ 
+         # MLP + gate
+         x_norm = self.ln_2(x)
+         mlp_output = self.mlp(x_norm)
+         fsg_scores_2 = self.fsg_rectifier(self.fsg_rgb_ls2)
+         x = x + mlp_output * fsg_scores_2  # Residual connection weighted by gate
+ 
+         return x
+ 
+ class ViTwithFSG(nn.Module):
+     """
+     Wrapper module that injects FSGBlocks into each Transformer encoder block of a given ViT model.
+     """
+     def __init__(self, vit_backbone: VisionTransformer):
+         super().__init__()
+         self.vit = vit_backbone
+         for i, blk in enumerate(self.vit.encoder.layers):
+             self.vit.encoder.layers[i] = FSGBlock(blk)  # Replace original block with FSGBlock
+ 
+     def forward(self, x):
+         return self.vit(x)
+ 
+ def vit_with_fsg(vit_backbone: VisionTransformer):
+     """
+     Factory function that wraps a torchvision VisionTransformer with FSG-enhanced encoder blocks.
+     """
+     return ViTwithFSG(vit_backbone)
+ 
+ 
+ # === Example Usage ===
+ if __name__ == "__main__":
+     import warnings
+     warnings.filterwarnings("ignore", message="Failed to load image Python extension*")
+ 
+     from torchvision.models import vit_b_16, ViT_B_16_Weights
+ 
+     print("\n📥 Loading pretrained ViT_B_16 backbone from torchvision...")
+     backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+ 
+     print("🔧 Wrapping with Feature Selection Gates (FSG)...")
+     model = vit_with_fsg(vit_backbone=backbone)
+ 
+     print("🧪 Running dummy input through FSG-augmented ViT...")
+     dummy_input = torch.randn(1, 3, 224, 224)
+     output = model(dummy_input)
+ 
+     print("✅ Inference completed.")
+     print("📐 Output shape:", output.shape)