groffo
committed on
Commit · 8573586
0 Parent(s):
Initial commit of FSG-ViT
- .idea/.gitignore +3 -0
- .idea/ViT_with_FSG.iml +12 -0
- .idea/inspectionProfiles/Project_Default.xml +6 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- README.md +144 -0
- demo_inference_imnet.py +124 -0
- demo_inference_mnist.py +108 -0
- demo_training_imnet.py +114 -0
- demo_training_mnist.py +106 -0
- vit_with_fsg.py +109 -0
.idea/.gitignore
ADDED
@@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml
.idea/ViT_with_FSG.iml
ADDED
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="Python 3.10 (cvpr)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
</module>
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="false" level="WARNING" enabled_by_default="false" />
  </profile>
</component>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (cvpr)" project-jdk-type="Python SDK" />
</project>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/ViT_with_FSG.iml" filepath="$PROJECT_DIR$/.idea/ViT_with_FSG.iml" />
    </modules>
  </component>
</project>
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>
README.md
ADDED
@@ -0,0 +1,144 @@
# 🔬 Feature Selection Gates (FSG) for Vision Transformers (ViT)

This repository provides a modular, extensible PyTorch implementation of **Feature Selection Gates (FSG)** with **Gradient Routing (GR)**, integrated into **Vision Transformers (ViTs)**. The approach is proposed in:

> **Feature Selection Gates with Gradient Routing for Endoscopic Image Computing**
> Giorgio Roffo, Carlo Biffi, Pietro Salvagnini, Andrea Cherubini
> Presented at MICCAI 2024
> 📄 [Paper](https://papers.miccai.org/miccai-2024/316-Paper0410.html) | 🧠 [arXiv](https://arxiv.org/abs/2407.04400) | 💻 [Code](https://github.com/cosmoimd/feature-selection-gates)

---

## 📌 What Is FSG?

**FSG** introduces **learnable gates** that sparsify transformer blocks by modulating their residual connections, acting as **online feature selectors**. The resulting **sparse connectivity** reduces overfitting and improves generalization, which is especially valuable on small and imbalanced datasets.

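Conceptually, each encoder block learns one gate logit per channel and rescales its residual branches with the sigmoid of that logit. A minimal sketch of the mechanism (the actual implementation lives in `vit_with_fsg.py`, class `FSGBlock`):

```python
import torch

def gated_residual(x: torch.Tensor, branch_out: torch.Tensor, gate_logits: torch.Tensor) -> torch.Tensor:
    """Residual update applied twice per block: once after attention, once after the MLP."""
    gate = torch.sigmoid(gate_logits)   # one learnable logit per channel, squashed to (0, 1)
    return x + branch_out * gate        # low gate values effectively prune a channel's contribution
```
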
**Gradient Routing (GR)** enables dual-phase optimization:
- One optimizer updates FSG parameters
- A second optimizer updates the base model

This separation allows **task-specific tuning** and ensures stable learning.

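In practice, GR amounts to routing the gradients of one backward pass into two differently tuned updates. A minimal sketch, assuming two separate optimizers as described above (the demo scripts in this repository use a single `AdamW` with two parameter groups instead, which follows the same idea):

```python
import torch.optim as optim

# model is assumed to be an FSG-wrapped ViT (see "How to Use the FSG Wrapper" below)
fsg_params  = [p for n, p in model.named_parameters() if "fsg_rgb_ls" in n]
base_params = [p for n, p in model.named_parameters() if "fsg_rgb_ls" not in n]

opt_base = optim.AdamW(base_params, lr=1e-4)  # base ViT weights
opt_fsg  = optim.AdamW(fsg_params,  lr=5e-4)  # FSG gates, with their own (higher) learning rate

def training_step(images, targets, criterion):
    opt_base.zero_grad()
    opt_fsg.zero_grad()
    loss = criterion(model(images), targets)
    loss.backward()   # a single backward pass fills gradients for both parameter groups
    opt_base.step()   # ...which are then applied by each optimizer separately
    opt_fsg.step()
    return loss.item()
```
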
---

## 💡 Why Use FSG?

✅ **Plug & play**: can be integrated into **any ViT architecture**
✅ Works on **natural images**, **medical images**, and beyond
✅ Can be adapted to **NLP Transformers** such as GPTs and BERT
✅ Lightweight and highly regularizing
✅ Compatible with **multi-stream CNNs** and hybrid models

⚠️ While our focus is on **endoscopic image computing**, the method has also shown performance improvements on **CIFAR-100**, demonstrating its applicability to **standard vision tasks**.

---

## 🧪 How to Use the FSG Wrapper

Use the `vit_with_fsg.py` module to augment a pretrained ViT from `torchvision`:

```python
from torchvision.models import vit_b_16, ViT_B_16_Weights
from vit_with_fsg import vit_with_fsg
import torch

print("📥 Loading pretrained ViT_B_16...")
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

print("🔧 Wrapping with Feature Selection Gates (FSG)...")
model = vit_with_fsg(vit_backbone=backbone)

print("🧪 Running dummy input...")
dummy_input = torch.randn(1, 3, 224, 224)
output = model(dummy_input)

print("✅ Done. Output shape:", output.shape)
```

---

## 🚀 Demo Scripts

We provide full working training and inference examples:

| Dataset    | Training Script          | Inference Script          | Checkpoint Path                             |
|------------|--------------------------|---------------------------|---------------------------------------------|
| MNIST      | `demo_training_mnist.py` | `demo_inference_mnist.py` | `./checkpoints/fsg_vit_mnist_demo.pth`      |
| Imagenette | `demo_training_imnet.py` | `demo_inference_imnet.py` | `./checkpoints/fsg_vit_imagenette_demo.pth` |

Each demo:
- Trains a ViT-B/16 with FSG on a reduced dataset for speed.
- Uses separate learning rates for FSG and base model parameters.
- Includes GPU-aware prints and a training progress bar.
- Saves checkpoints for reproducible inference.

### ▶️ Example Usage

```bash
# Train on Imagenette
python demo_training_imnet.py

# Inference on Imagenette
python demo_inference_imnet.py --checkpoint ./checkpoints/fsg_vit_imagenette_demo.pth
```

```bash
# Train on MNIST
python demo_training_mnist.py

# Inference on MNIST
python demo_inference_mnist.py --checkpoint ./checkpoints/fsg_vit_mnist_demo.pth
```

> ⚠️ These demos use reduced test sets and train for only a few iterations to keep runtimes short. They are not meant for benchmarking, but for showcasing FSG integration.

---

## 🧠 Applicability Beyond Endoscopy

Although designed for **polyp size estimation in colonoscopy**, FSG is a **general mechanism** for:
- **Image classification**
- **Medical image analysis**
- **Multimodal fusion**
- **NLP Transformers** (e.g., GPTs, BERT): apply FSG over token embeddings

We strongly encourage researchers to test FSG in **non-medical** domains; a minimal fine-tuning sketch follows.

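For instance, reusing the wrapper on a new image-classification dataset only requires swapping the classification head before fine-tuning. A minimal sketch, assuming the standard `torchvision` head layout (`heads.head`); the 5-class task is purely illustrative:

```python
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
from vit_with_fsg import vit_with_fsg

# Wrap a pretrained backbone as usual
model = vit_with_fsg(vit_b_16(weights=ViT_B_16_Weights.DEFAULT))

# Replace the ImageNet head with one sized for a hypothetical 5-class task
in_features = model.vit.heads.head.in_features
model.vit.heads.head = nn.Linear(in_features, 5)
```

From there, fine-tune exactly as in the demo scripts, keeping a separate (higher) learning rate for the `fsg_rgb_ls*` gate parameters.
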
---

## 📦 Files and Structure

```
.
├── vit_with_fsg.py           # ViT + FSG wrapper
├── demo_training_mnist.py
├── demo_inference_mnist.py
├── demo_training_imnet.py
├── demo_inference_imnet.py
├── checkpoints/              # Folder for .pth checkpoints
```

---

## 📚 Citation

Please cite our work if you use this repository:

```bibtex
@inproceedings{roffo2024FSG,
  title={Feature Selection Gates with Gradient Routing for Endoscopic Image Computing},
  author={Giorgio Roffo and Carlo Biffi and Pietro Salvagnini and Andrea Cherubini},
  booktitle={MICCAI 2024, the 27th International Conference on Medical Image Computing and Computer Assisted Intervention, Marrakech, Morocco, October 2024},
  year={2024},
  organization={Springer}
}
```

---

## 📬 Contact

Lead Author: **Giorgio Roffo**
📧 giorgio.roffo@gmail.com
🏢 Cosmo Intelligent Medical Devices (IMD), Lainate, Italy

For more: [github.com/cosmoimd/feature-selection-gates](https://github.com/cosmoimd/feature-selection-gates)
demo_inference_imnet.py
ADDED
@@ -0,0 +1,124 @@
'''
Demo script for applying Feature Selection Gates (FSG) to torchvision Vision Transformers
and running inference on the ImageNet-mini (Imagenette) validation set.

Each image is resized to 224x224 and has 3 RGB channels to be compatible with ViT.

Usage:

    demo_inference_imnet.py --checkpoint ./checkpoints/fsg_vit_imagenette_demo.pth

Paper:
    https://papers.miccai.org/miccai-2024/316-Paper0410.html
Code:
    https://github.com/cosmoimd/feature-selection-gates
Contact:
    giorgio.roffo@gmail.com
'''

import warnings
warnings.filterwarnings("ignore")

import os
import sys
import tarfile
import urllib.request
import torch
import psutil
from torchvision.models import vit_b_16, ViT_B_16_Weights
from vit_with_fsg import vit_with_fsg
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm

import argparse

parser = argparse.ArgumentParser(description="FSG-ViT inference on Imagenette")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to .pth file of trained FSG-ViT model")
args = parser.parse_args()

if __name__ == "__main__":
    warnings.filterwarnings("ignore", message="Failed to load image Python extension*")
    wrn = False
    print(f"\n📌 To run this script:\n"
          f"   ▶ Without checkpoint:  python {os.path.basename(__file__)}\n"
          f"   ▶ With checkpoint:     python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")

    # Device and system info
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n🖥️ Using device: {device}")
    if device.type == "cuda":
        print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
    print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")

    print("\n📥 Loading pretrained ViT backbone from torchvision...")
    backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

    print("🔧 Wrapping with Feature Selection Gates (FSG)...")
    model = vit_with_fsg(backbone).to(device)

    if args.checkpoint is not None:
        print(f"📂 Loading model weights from: {args.checkpoint}")
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    else:
        wrn = True
        print("\n⚠️ No checkpoint provided. Evaluating randomly initialized model! 🧪\n")
        print("❗ Note: The model has not been trained. Results will reflect a randomly initialized backbone.")

    model.eval()

    print("📚 Loading Imagenette validation set (224x224 RGB)...")
    imagenette_path = "./imagenette2-160/val"
    if not os.path.exists(imagenette_path):
        print("📦 Downloading Imagenette...")
        url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
        tgz_path = "imagenette2-160.tgz"
        urllib.request.urlretrieve(url, tgz_path)
        print("📂 Extracting Imagenette dataset...")
        with tarfile.open(tgz_path, "r:gz") as tar:
            tar.extractall()
        os.remove(tgz_path)
        print("✅ Dataset ready.")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
    ])

    dataset = ImageFolder(root=imagenette_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    y_true = []
    y_pred = []

    print("🧪 Running inference on Imagenette validation set using FSG-ViT-B-16 (code by G. Roffo)...\n\n")
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="🔍 Inference progress", ncols=100):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)
            y_true.extend(labels.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    print("✅ Inference completed.")

    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    if wrn:
        print("\n⚠️ No checkpoint provided. Evaluated randomly initialized model! 🧪\n")
        print(f"\n📌 To run this script:\n"
              f"   ▶ With checkpoint:  python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")

    print(f"📊 Accuracy:  {acc * 100:.2f}%")
    print(f"📊 Precision: {prec * 100:.2f}%")
    print(f"📊 Recall:    {rec * 100:.2f}%")
    print(f"📊 F1 Score:  {f1 * 100:.2f}%")
demo_inference_mnist.py
ADDED
@@ -0,0 +1,108 @@
'''
Demo script for applying Feature Selection Gates (FSG) to torchvision Vision Transformers
and running inference on the MNIST test set.

Each MNIST image is resized to 224x224 and converted to 3 channels to be compatible with ViT.

Usage:

    demo_inference_mnist.py --checkpoint ./checkpoints/fsg_vit_mnist_demo.pth

Paper:
    https://papers.miccai.org/miccai-2024/316-Paper0410.html
Code:
    https://github.com/cosmoimd/feature-selection-gates
Contact:
    giorgio.roffo@gmail.com
'''

import torch
import psutil
import argparse
import warnings
from torchvision.models import vit_b_16, ViT_B_16_Weights
from vit_with_fsg import vit_with_fsg
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import os

warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser(description="FSG-ViT inference on MNIST")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to .pth file of trained FSG-ViT model")
args = parser.parse_args()

if __name__ == "__main__":
    warnings.filterwarnings("ignore", message="Failed to load image Python extension*")
    wrn = False
    print(f"\n📌 To run this script:\n"
          f"   ▶ Without checkpoint:  python {os.path.basename(__file__)}\n"
          f"   ▶ With checkpoint:     python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n🖥️ Using device: {device}")
    if device.type == "cuda":
        print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
    print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")

    print("\n📥 Loading pretrained ViT backbone from torchvision...")
    backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

    print("🔧 Wrapping with Feature Selection Gates (FSG)...")
    model = vit_with_fsg(backbone).to(device)

    if args.checkpoint is not None:
        print(f"📂 Loading model weights from: {args.checkpoint}")
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    else:
        wrn = True
        print("\n⚠️ No checkpoint provided. Evaluating randomly initialized model! 🧪\n")
        print("❗ Note: The model has not been trained. Results will reflect a randomly initialized backbone.")

    model.eval()

    print("📚 Loading MNIST test set (resized to 224x224, 3-channel)...")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    test_dataset = MNIST(root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    y_true = []
    y_pred = []

    print("🧪 Running inference on MNIST test set using FSG-ViT-B-16 (code by G. Roffo)...")
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="🔍 Inference progress", ncols=100):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)
            y_true.extend(labels.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    print("✅ Inference completed.")

    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    if wrn:
        print("\n⚠️ No checkpoint provided. Evaluated randomly initialized model! 🧪\n")
        print(f"\n📌 To run this script:\n"
              f"   ▶ With checkpoint:  python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")

    print(f"📊 Accuracy:  {acc * 100:.2f}%")
    print(f"📊 Precision: {prec * 100:.2f}%")
    print(f"📊 Recall:    {rec * 100:.2f}%")
    print(f"📊 F1 Score:  {f1 * 100:.2f}%")
demo_training_imnet.py
ADDED
@@ -0,0 +1,114 @@
'''
Demo training script for Feature Selection Gates (FSG) with ViT on Imagenette

This script loads the Imagenette dataset (ImageNet-mini),
trains a ViT model augmented with FSG, and saves the model checkpoint.

Paper:
    https://papers.miccai.org/miccai-2024/316-Paper0410.html
Code:
    https://github.com/cosmoimd/feature-selection-gates
Contact:
    giorgio.roffo@gmail.com
'''

import os
import tarfile
import urllib.request
import torch
import torch.nn as nn
import torch.optim as optim
import psutil
from tqdm import tqdm
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from vit_with_fsg import vit_with_fsg

# System info
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n🖥️ Using device: {device}")
if device.type == "cuda":
    print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")

# Dataset path
imagenette_path = "./imagenette2-160/val"
if not os.path.exists(imagenette_path):
    print("📦 Downloading Imagenette...")
    url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
    tgz_path = "imagenette2-160.tgz"
    urllib.request.urlretrieve(url, tgz_path)
    print("📂 Extracting Imagenette dataset...")
    with tarfile.open(tgz_path, "r:gz") as tar:
        tar.extractall()
    os.remove(tgz_path)
    print("✅ Dataset ready.")

# Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
])

# Dataset and loader
dataset = ImageFolder(root=imagenette_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Model setup
print("\n📥 Loading pretrained ViT backbone from torchvision...")
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
model = vit_with_fsg(backbone).to(device)

# Optimizer with separate LRs for FSG and base ViT
# (FSG gate parameters are named fsg_rgb_ls1 / fsg_rgb_ls2 in FSGBlock)
fsg_params, base_params = [], []
for name, param in model.named_parameters():
    if 'fsg_rgb_ls' in name:
        fsg_params.append(param)
    else:
        base_params.append(param)

lr_base = 1e-4
lr_fsg = 5e-4
print(f"\n🔧 Optimizer setup:")
print(f"   🔹 Base ViT parameters LR: {lr_base}")
print(f"   🔸 FSG parameters LR: {lr_fsg}")

optimizer = optim.AdamW([
    {"params": base_params, "lr": lr_base},
    {"params": fsg_params, "lr": lr_fsg}
])
criterion = nn.CrossEntropyLoss()

# Training loop
epochs = 3
print(f"\n🚀 Starting demo training for {epochs} epochs...")
model.train()
for epoch in range(epochs):
    steps_demo = 0  # to remove: for demo only
    running_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", ncols=100)
    for inputs, targets in pbar:
        if steps_demo > 25:  # to remove: for demo only
            break  # to remove: for demo only
        steps_demo += 1  # to remove: for demo only
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pbar.set_postfix({"loss": running_loss / (pbar.n + 1e-8)})

print("\n✅ Training complete.")

# Save checkpoint
ckpt_dir = "./checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_path = os.path.join(ckpt_dir, "fsg_vit_imagenette_demo.pth")
torch.save(model.state_dict(), ckpt_path)
print(f"💾 Checkpoint saved to: {ckpt_path}")
demo_training_mnist.py
ADDED
@@ -0,0 +1,106 @@
'''
Demo training script for Feature Selection Gates (FSG) with ViT on the MNIST test set

This is a minimal demo: we train only on the MNIST test set (resized and converted to 3-channel)
for a few epochs to simulate training, save the checkpoint, and allow downstream inference.

Paper:
    https://papers.miccai.org/miccai-2024/316-Paper0410.html
Code:
    https://github.com/cosmoimd/feature-selection-gates
Contact:
    giorgio.roffo@gmail.com
'''

import os
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
import psutil
from tqdm import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torch.utils.data import DataLoader
from vit_with_fsg import vit_with_fsg

warnings.filterwarnings("ignore")

# Device info
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n🖥️ Using device: {device}")
if device.type == "cuda":
    print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")

# Dataset loading
print("\n📚 Loading MNIST demo set for demo training (resized to 224x224, 3-channel)...")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = MNIST(root="./data", train=False, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Load ViT backbone and wrap with FSG
print("\n📥 Loading pretrained ViT backbone from torchvision...")
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
model = vit_with_fsg(backbone).to(device)

# Prepare optimizer with different LRs for FSG parameters and base model
# (FSG gate parameters are named fsg_rgb_ls1 / fsg_rgb_ls2 in FSGBlock)
fsg_params = []
base_params = []
for name, param in model.named_parameters():
    if 'fsg_rgb_ls' in name:
        fsg_params.append(param)
    else:
        base_params.append(param)

# Assign a higher LR to FSG parameters, lower to base ViT params
lr_base = 1e-4
lr_fsg = 5e-4
print(f"\n🔧 Optimizer setup:")
print(f"   🔹 Base ViT parameters LR: {lr_base}")
print(f"   🔸 FSG parameters LR: {lr_fsg}")

optimizer = optim.AdamW([
    {"params": base_params, "lr": lr_base},
    {"params": fsg_params, "lr": lr_fsg}
])

criterion = nn.CrossEntropyLoss()
epochs = 3
print(f"\n🚀 Starting demo training for {epochs} epochs...")

model.train()
for epoch in range(epochs):
    steps_demo = 0  # to remove: for demo only
    running_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", ncols=100)
    for inputs, targets in pbar:
        if steps_demo > 25:  # to remove: for demo only
            break  # to remove: for demo only
        steps_demo += 1  # to remove: for demo only
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix({"loss": running_loss / (pbar.n + 1e-8)})

print("\n✅ Training complete.")

# Save checkpoint
ckpt_dir = "./checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_path = os.path.join(ckpt_dir, "fsg_vit_mnist_demo.pth")
torch.save(model.state_dict(), ckpt_path)
print(f"💾 Checkpoint saved to: {ckpt_path}")
vit_with_fsg.py
ADDED
@@ -0,0 +1,109 @@
'''
ViTwithFSG: Vision Transformer wrapper with Feature Selection Gates (FSG)

This script defines a wrapper class to apply Feature Selection Gates (FSG) to a Vision Transformer (ViT) model.
FSG enhances model generalization by introducing sparse, learnable gates on the residual paths of attention and MLP blocks.
It is a form of architectural regularization designed for vision tasks and applicable to NLP tasks.

The method is introduced in:

@inproceedings{roffo2024FSG,
  title={Feature Selection Gates with Gradient Routing for Endoscopic Image Computing},
  author={Giorgio Roffo and Carlo Biffi and Pietro Salvagnini and Andrea Cherubini},
  booktitle={MICCAI 2024, the 27th International Conference on Medical Image Computing and Computer Assisted Intervention, Marrakech, Morocco, October 2024},
  year={2024},
  organization={Springer}
}

- Publication: https://papers.miccai.org/miccai-2024/316-Paper0410.html
- Code: https://github.com/cosmoimd/feature-selection-gates
- Contact: giorgio.roffo@gmail.com
- Affiliation: Cosmo Intelligent Medical Devices (IMD), Lainate, Italy
'''

# imports
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torchvision.models.vision_transformer import VisionTransformer


class FSGBlock(nn.Module):
    """
    A Transformer encoder block augmented with Feature Selection Gates (FSG).
    Each residual path (attention and MLP) is weighted element-wise by a learnable sigmoid gate.
    This promotes sparse activation and serves as a regularization mechanism to avoid overfitting.
    """
    def __init__(self, original_block):
        super().__init__()
        self.self_attention = original_block.self_attention  # Multi-head self-attention module
        self.mlp = original_block.mlp                         # Feedforward network (2-layer MLP)
        self.ln_1 = original_block.ln_1                       # LayerNorm before attention
        self.ln_2 = original_block.ln_2                       # LayerNorm before MLP
        self.dropout = original_block.dropout                 # Dropout after attention

        dim = self.ln_1.normalized_shape[0]                   # Dimensionality of the model

        # FSG: learnable gates (one per channel), initialized with Xavier normal
        self.fsg_rectifier = nn.Sigmoid()
        self.fsg_rgb_ls1 = nn.Parameter(torch.empty(dim))     # Gate for attention path
        self.fsg_rgb_ls2 = nn.Parameter(torch.empty(dim))     # Gate for MLP path
        nn.init.xavier_normal_(self.fsg_rgb_ls1.unsqueeze(0), gain=nn.init.calculate_gain('sigmoid'))
        nn.init.xavier_normal_(self.fsg_rgb_ls2.unsqueeze(0), gain=nn.init.calculate_gain('sigmoid'))

    def forward(self, x):
        # Self-attention + gate
        x_norm = self.ln_1(x)
        attn_output, _ = self.self_attention(x_norm, x_norm, x_norm, need_weights=False)
        attn_output = self.dropout(attn_output)
        fsg_scores_1 = self.fsg_rectifier(self.fsg_rgb_ls1)
        x = x + attn_output * fsg_scores_1  # Residual connection weighted by gate

        # MLP + gate
        x_norm = self.ln_2(x)
        mlp_output = self.mlp(x_norm)
        fsg_scores_2 = self.fsg_rectifier(self.fsg_rgb_ls2)
        x = x + mlp_output * fsg_scores_2  # Residual connection weighted by gate

        return x


class ViTwithFSG(nn.Module):
    """
    Wrapper module that injects FSGBlocks into each Transformer encoder block of a given ViT model.
    """
    def __init__(self, vit_backbone: VisionTransformer):
        super().__init__()
        self.vit = vit_backbone
        for i, blk in enumerate(self.vit.encoder.layers):
            self.vit.encoder.layers[i] = FSGBlock(blk)  # Replace original block with FSGBlock

    def forward(self, x):
        return self.vit(x)


def vit_with_fsg(vit_backbone: VisionTransformer):
    """
    Factory function that wraps a torchvision VisionTransformer with FSG-enhanced encoder blocks.
    """
    return ViTwithFSG(vit_backbone)


# === Example Usage ===
if __name__ == "__main__":
    import warnings
    warnings.filterwarnings("ignore", message="Failed to load image Python extension*")

    from torchvision.models import vit_b_16, ViT_B_16_Weights

    print("\n📥 Loading pretrained ViT_B_16 backbone from torchvision...")
    backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

    print("🔧 Wrapping with Feature Selection Gates (FSG)...")
    model = vit_with_fsg(vit_backbone=backbone)

    print("🧪 Running dummy input through FSG-augmented ViT...")
    dummy_input = torch.randn(1, 3, 224, 224)
    output = model(dummy_input)

    print("✅ Inference completed.")
    print("📐 Output shape:", output.shape)