🔬 Feature Selection Gates (FSG) for Vision Transformers (ViT)

This repository implements Feature Selection Gates (FSG) and Gradient Routing (GR) as a modular extension to Vision Transformers. It is based on our paper presented at MICCAI 2024:

Feature Selection Gates with Gradient Routing for Endoscopic Image Computing
Giorgio Roffo, Carlo Biffi, Pietro Salvagnini, Andrea Cherubini
MICCAI 2024, arXiv, GitHub


🧠 What Is FSG?

FSG introduces learnable gates on residual branches within Transformer layers. These gates:

  • Dynamically select relevant features
  • Promote sparse connectivity during training
  • Serve as a form of architectural regularization

To stabilize learning, Gradient Routing (GR) uses a dual-pass strategy:

  • One forward pass to compute gradients for the base model
  • A separate route to update FSG parameters independently
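
The gating and routing ideas above can be illustrated with a framework-free toy. This is a sketch under stated assumptions, not the repository's implementation: the sigmoid gate, the per-feature branch weight `w`, the squared-error loss, the two learning rates, and all function names are hypothetical choices made for illustration. Each step runs two passes: the first updates only the base weights, the second recomputes the forward pass and updates only the gate logits.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gated_residual(x, w, g):
    """y_i = x_i + sigmoid(g_i) * (w_i * x_i): identity path plus a gated branch."""
    return [xi + sigmoid(gi) * wi * xi for xi, wi, gi in zip(x, w, g)]


def loss(x, target, w, g):
    """Sum of squared errors between the gated output and the target."""
    return sum((yi - ti) ** 2 for yi, ti in zip(gated_residual(x, w, g), target))


def gradients(x, target, w, g):
    """Analytic gradients of the loss w.r.t. branch weights and gate logits."""
    y = gated_residual(x, w, g)
    grad_w, grad_g = [], []
    for xi, ti, wi, gi, yi in zip(x, target, w, g, y):
        err = 2.0 * (yi - ti)
        s = sigmoid(gi)
        grad_w.append(err * s * xi)                   # dL/dw_i
        grad_g.append(err * wi * xi * s * (1.0 - s))  # dL/dg_i
    return grad_w, grad_g


def train(x, target, steps=200, lr_base=0.05, lr_gate=0.5):
    """Alternate two update routes (assumed scheme): base weights, then gates."""
    w = [0.5 for _ in x]  # base (branch) parameters
    g = [0.0 for _ in x]  # gate logits, sigmoid(0) = 0.5
    for _ in range(steps):
        # Pass 1: update only the base weights; gates are held fixed.
        grad_w, _ = gradients(x, target, w, g)
        w = [wi - lr_base * gw for wi, gw in zip(w, grad_w)]
        # Pass 2: recompute the forward pass and update only the gates.
        _, grad_g = gradients(x, target, w, g)
        g = [gi - lr_gate * gg for gi, gg in zip(g, grad_g)]
    return w, g


if __name__ == "__main__":
    x, target = [1.0, 1.0], [2.0, 1.0]  # feature 0 needs its branch, feature 1 does not
    w, g = train(x, target)
    print("gate values:", [round(sigmoid(gi), 3) for gi in g])
    print("final loss:", loss(x, target, w, g))
```

In this toy the gate of the first feature opens (rises above 0.5) because its branch is needed to reach the target, while the gate of the second feature closes. In a PyTorch model the same separation could be expressed as two optimizers over disjoint parameter groups, which is again an assumption rather than a description of this repository's code.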

💡 Key Features

  • ✅ Drop-in: Easily wraps any torchvision ViT model (e.g. vit_b_16, vit_l_16)
  • ✅ General-purpose: Use on natural images, medical data, and even token sequences in NLP
  • ✅ Regularizes ViTs for low-data regimes (tested on CIFAR-100, endoscopic videos, etc.)
  • ✅ No ViT surgery: FSG wraps Transformer layers directly

While this method was originally proposed for polyp size estimation in colonoscopy, it is designed to generalize across:

  • 🧬 Medical image analysis
  • 🖼️ General image classification
  • 📚 NLP Transformers (e.g. GPT, BERT)

🧪 Minimal Example

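import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights
from vit_with_fsg import vit_with_fsg

# Load an ImageNet-pretrained ViT-B/16 backbone from torchvision
print("📥 Loading pretrained ViT...")
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

# Wrap the backbone with Feature Selection Gates
print("🔧 Injecting FSG into backbone...")
model = vit_with_fsg(vit_backbone=backbone)

# Sanity check: pass one random 224x224 RGB image through the wrapped model
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():  # inference only, no gradients needed
    output = model(dummy_input)
print("✅ Output shape:", output.shape)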

🧪 Demos (Quick Training + Inference)

Dataset      Training Script          Inference Script          Checkpoint Path
MNIST        demo_training_mnist.py   demo_inference_mnist.py   ./checkpoints/fsg_vit_mnist_demo.pth
Imagenette   demo_training_imnet.py   demo_inference_imnet.py   ./checkpoints/fsg_vit_imagenette_demo.pth

⚠️ These demos use reduced datasets and epochs to run quickly and demonstrate the API.


📦 Project Structure

.
├── vit_with_fsg.py                  # FSG-ViT integration
├── demo_training_mnist.py
├── demo_inference_mnist.py
├── demo_training_imnet.py
├── demo_inference_imnet.py
├── checkpoints/                     # Model weights (optional)
└── README.md                        # This model card

📚 Citation

If you use this project, please cite our work:

@inproceedings{roffo2024FSG,
   title={Feature Selection Gates with Gradient Routing for Endoscopic Image Computing},
   author={Giorgio Roffo and Carlo Biffi and Pietro Salvagnini and Andrea Cherubini},
   booktitle={MICCAI 2024, the 27th International Conference on Medical Image Computing and Computer Assisted Intervention, Marrakech, Morocco, October 2024.},
   year={2024},
   organization={Springer}
}

📬 Contact

Giorgio Roffo
📧 giorgio.roffo@gmail.com
🏢 Cosmo Intelligent Medical Devices (IMD), Lainate, Italy
🔗 github.com/cosmoimd/feature-selection-gates
