Feature Selection Gates (FSG) for Vision Transformers (ViT)
This repository implements Feature Selection Gates (FSG) and Gradient Routing (GR) as a modular extension to Vision Transformers. It is based on our paper presented at MICCAI 2024:
Feature Selection Gates with Gradient Routing for Endoscopic Image Computing
Giorgio Roffo, Carlo Biffi, Pietro Salvagnini, Andrea Cherubini
MICCAI 2024, arXiv, GitHub
What Is FSG?
FSG introduces learnable gates on residual branches within Transformer layers. These gates:
- Dynamically select relevant features
- Promote sparse connectivity during training
- Serve as a form of architectural regularization
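The gating idea can be illustrated with a minimal PyTorch sketch. It assumes a sigmoid gate with one learnable logit per feature channel, applied to the output of a residual branch. The class name `GatedResidual`, the zero initialization, and the toy branch are hypothetical choices for illustration; they are not taken from `vit_with_fsg.py`, which may implement the gates differently.

```python
import torch
import torch.nn as nn


class GatedResidual(nn.Module):
    """Hypothetical gated block: y = x + sigmoid(g) * branch(x)."""

    def __init__(self, branch: nn.Module, dim: int):
        super().__init__()
        self.branch = branch
        # One learnable logit per feature channel (assumed parameterization).
        # Zeros give sigmoid(0) = 0.5, so every feature starts half open.
        self.gate_logits = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(self.gate_logits)  # values in (0, 1)
        return x + gate * self.branch(x)        # gate broadcasts over tokens


torch.manual_seed(0)
dim = 8
# Toy stand-in for an attention or MLP branch inside a Transformer layer.
branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))
block = GatedResidual(branch, dim)

tokens = torch.randn(2, 5, dim)  # (batch, tokens, features)
out = block(tokens)
print(out.shape)
```

A gate value near 0 suppresses the corresponding feature of the branch output, while a value near 1 lets it through unchanged, which is how a gate of this kind can act as a soft feature selector.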
To stabilize learning, Gradient Routing (GR) performs a dual-pass strategy:
- One forward pass to compute gradients for the base model
- A separate route to update FSG parameters independently
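One way to picture the dual-pass strategy is with two optimizers over disjoint parameter groups, each stepped after its own forward pass. The sketch below is an assumption-laden illustration, not the paper's exact procedure: the toy model `TinyGatedNet`, the SGD optimizers, the learning rates, and the order of the two passes are all hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyGatedNet(nn.Module):
    """Toy model with one gated residual branch (illustrative only)."""

    def __init__(self, dim: int = 8, classes: int = 3):
        super().__init__()
        self.body = nn.Linear(dim, dim)
        self.head = nn.Linear(dim, classes)
        self.gate_logits = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x + torch.sigmoid(self.gate_logits) * self.body(x))


model = TinyGatedNet()

# Split the parameters into two disjoint groups: gates vs. everything else.
gate_params = [model.gate_logits]
base_params = [p for n, p in model.named_parameters() if n != "gate_logits"]

base_opt = torch.optim.SGD(base_params, lr=1e-2)  # assumed learning rate
gate_opt = torch.optim.SGD(gate_params, lr=1e-1)  # assumed learning rate
loss_fn = nn.CrossEntropyLoss()


def train_step(x: torch.Tensor, y: torch.Tensor) -> float:
    # Pass 1: gradients are computed for all parameters,
    # but only the base model is updated.
    model.zero_grad(set_to_none=True)
    loss_fn(model(x), y).backward()
    base_opt.step()

    # Pass 2: a fresh forward pass; only the gate parameters are updated.
    model.zero_grad(set_to_none=True)
    loss = loss_fn(model(x), y)
    loss.backward()
    gate_opt.step()
    return loss.item()


x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

gates_before = model.gate_logits.detach().clone()
body_before = model.body.weight.detach().clone()
final_loss = train_step(x, y)
```

Keeping the gates in their own optimizer lets them use a different learning rate or schedule from the backbone, which is one plausible reading of "update FSG parameters independently".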
Key Features
- Drop-in: Easily wraps any torchvision ViT model (e.g. vit_b_16, vit_l_16)
- General-purpose: Use on natural images, medical data, and even token sequences in NLP
- Regularizes ViTs for low-data regimes (tested on CIFAR-100, endoscopic videos, etc.)
- No ViT surgery: FSG wraps Transformer layers directly
While this method was originally proposed for polyp size estimation in colonoscopy, it is designed to generalize across:
- Medical image analysis
- General image classification
- NLP Transformers (e.g. GPT, BERT)
Minimal Example
import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights
from vit_with_fsg import vit_with_fsg

# Load a torchvision ViT-B/16 with its default pretrained weights
print("Loading pretrained ViT...")
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

# Wrap the backbone so its Transformer layers use FSG
print("Injecting FSG into backbone...")
model = vit_with_fsg(vit_backbone=backbone)

# Run a single random 224x224 RGB image through the model
dummy_input = torch.randn(1, 3, 224, 224)
output = model(dummy_input)
print("Output shape:", output.shape)
Demos (Quick Training + Inference)
Dataset | Training Script | Inference Script | Checkpoint Path |
---|---|---|---|
MNIST | demo_training_mnist.py | demo_inference_mnist.py | ./checkpoints/fsg_vit_mnist_demo.pth |
Imagenette | demo_training_imnet.py | demo_inference_imnet.py | ./checkpoints/fsg_vit_imagenette_demo.pth |
Note: These demos use reduced datasets and fewer epochs so that they run quickly; they are meant to demonstrate the API.
Project Structure
.
├── vit_with_fsg.py # FSG-ViT integration
├── demo_training_mnist.py
├── demo_inference_mnist.py
├── demo_training_imnet.py
├── demo_inference_imnet.py
├── checkpoints/ # Model weights (optional)
└── README.md # This model card
Citation
If you use this project, please cite our work:
@inproceedings{roffo2024FSG,
title={Feature Selection Gates with Gradient Routing for Endoscopic Image Computing},
author={Giorgio Roffo and Carlo Biffi and Pietro Salvagnini and Andrea Cherubini},
booktitle={MICCAI 2024, the 27th International Conference on Medical Image Computing and Computer Assisted Intervention, Marrakech, Morocco, October 2024.},
year={2024},
organization={Springer}
}
Contact
Giorgio Roffo
giorgio.roffo@gmail.com
Cosmo Intelligent Medical Devices (IMD), Lainate, Italy
github.com/cosmoimd/feature-selection-gates