'''
ViTwithFSG: Vision Transformer wrapper with Feature Selection Gates (FSG)

This script defines a wrapper class that applies Feature Selection Gates (FSG) to a Vision Transformer (ViT) model.
FSG enhances model generalization by introducing sparse, learnable gates on the residual paths of the attention and MLP blocks.
It is a form of architectural regularization designed for vision tasks and equally applicable to NLP tasks.

The method is introduced in:

    @inproceedings{roffo2024FSG,
        title={Feature Selection Gates with Gradient Routing for Endoscopic Image Computing},
        author={Giorgio Roffo and Carlo Biffi and Pietro Salvagnini and Andrea Cherubini},
        booktitle={27th International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI 2024), Marrakech, Morocco, October 2024},
        year={2024},
        organization={Springer}
    }

- Publication: https://papers.miccai.org/miccai-2024/316-Paper0410.html
- Code: https://github.com/cosmoimd/feature-selection-gates
- Contact: giorgio.roffo@gmail.com
- Affiliation: Cosmo Intelligent Medical Devices (IMD), Lainate, Italy
'''
# imports
import warnings

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torchvision.models.vision_transformer import VisionTransformer


class FSGBlock(nn.Module):
    """
    A Transformer encoder block augmented with Feature Selection Gates (FSG).

    Each residual path (attention and MLP) is weighted element-wise by a learnable sigmoid gate.
    This promotes sparse activation and serves as a regularization mechanism to avoid overfitting.
    """
    def __init__(self, original_block):
        super().__init__()
        self.self_attention = original_block.self_attention  # Multi-head self-attention module
        self.mlp = original_block.mlp  # Feedforward network (2-layer MLP)
        self.ln_1 = original_block.ln_1  # LayerNorm before attention
        self.ln_2 = original_block.ln_2  # LayerNorm before MLP
        self.dropout = original_block.dropout  # Dropout after attention
        dim = self.ln_1.normalized_shape[0]  # Dimensionality of the model

        # FSG: learnable gates (one per channel), initialized with Xavier normal
        self.fsg_rectifier = nn.Sigmoid()
        self.fsg_rgb_ls1 = nn.Parameter(torch.empty(dim))  # Gate for attention path
        self.fsg_rgb_ls2 = nn.Parameter(torch.empty(dim))  # Gate for MLP path
        nn.init.xavier_normal_(self.fsg_rgb_ls1.unsqueeze(0), gain=nn.init.calculate_gain('sigmoid'))
        nn.init.xavier_normal_(self.fsg_rgb_ls2.unsqueeze(0), gain=nn.init.calculate_gain('sigmoid'))
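        # Note: xavier_normal_ requires at least a 2-D tensor, so each 1-D gate is
        # initialized through an unsqueeze(0) view, which shares storage with the
        # underlying parameter and therefore writes the init values through to it.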

    def forward(self, x):
        # Self-attention + gate
        x_norm = self.ln_1(x)
        attn_output, _ = self.self_attention(x_norm, x_norm, x_norm, need_weights=False)
        attn_output = self.dropout(attn_output)
        fsg_scores_1 = self.fsg_rectifier(self.fsg_rgb_ls1)
        x = x + attn_output * fsg_scores_1  # Residual connection weighted by gate

        # MLP + gate
        x_norm = self.ln_2(x)
        mlp_output = self.mlp(x_norm)
        fsg_scores_2 = self.fsg_rectifier(self.fsg_rgb_ls2)
        x = x + mlp_output * fsg_scores_2  # Residual connection weighted by gate
        return x
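
# In equation form, writing g1 = fsg_rgb_ls1 and g2 = fsg_rgb_ls2, each FSGBlock computes
#
#     x <- x + sigmoid(g1) * Dropout(MHSA(LN_1(x)))
#     x <- x + sigmoid(g2) * MLP(LN_2(x))
#
# i.e. the standard pre-norm ViT encoder block with each residual branch rescaled
# channel-wise by its gate; the (dim,)-shaped gates broadcast over batch and tokens.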


class ViTwithFSG(nn.Module):
    """
    Wrapper module that injects FSGBlocks into each Transformer encoder block of a given ViT model.
    """

    def __init__(self, vit_backbone: VisionTransformer):
        super().__init__()
        self.vit = vit_backbone
        for i, blk in enumerate(self.vit.encoder.layers):
            self.vit.encoder.layers[i] = FSGBlock(blk)  # Replace original block with FSGBlock

    def forward(self, x):
        return self.vit(x)


def vit_with_fsg(vit_backbone: VisionTransformer):
    """
    Factory function that wraps a torchvision VisionTransformer with FSG-enhanced encoder blocks.
    """
    return ViTwithFSG(vit_backbone)
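
# A minimal sketch (illustration, not part of the original training recipe): since the
# gates are ordinary nn.Parameters, one could freeze the pretrained backbone and train
# only the FSG gates for lightweight adaptation. Attribute names follow FSGBlock above.
#
#     model = vit_with_fsg(vit_b_16(weights=ViT_B_16_Weights.DEFAULT))
#     for name, p in model.named_parameters():
#         p.requires_grad = name.endswith(("fsg_rgb_ls1", "fsg_rgb_ls2"))
#     optimizer = torch.optim.AdamW(
#         (p for p in model.parameters() if p.requires_grad), lr=1e-3)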


# === Example Usage ===
if __name__ == "__main__":
    import warnings
    warnings.filterwarnings("ignore", message="Failed to load image Python extension*")

    from torchvision.models import vit_b_16, ViT_B_16_Weights

    print("\n📥 Loading pretrained ViT_B_16 backbone from torchvision...")
    backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

    print("🔧 Wrapping with Feature Selection Gates (FSG)...")
    model = vit_with_fsg(vit_backbone=backbone)

    print("🧪 Running dummy input through FSG-augmented ViT...")
    dummy_input = torch.randn(1, 3, 224, 224)
    output = model(dummy_input)

    print("✅ Inference completed.")
    print("📐 Output shape:", output.shape)