---
datasets:
- ILSVRC/imagenet-21k
license: other
license_name: nvclv1
license_link: LICENSE
pipeline_tag: image-classification
library_name: transformers
---

[**MambaVision: A Hybrid Mamba-Transformer Vision Backbone**](https://arxiv.org/abs/2407.08083).

[Project page](https://github.com/NVlabs/MambaVision)

## Model Overview

We have developed the first hybrid model for computer vision that leverages the strengths of Mamba and Transformers. Specifically, our core contribution includes redesigning the Mamba formulation to enhance its capability for efficient modeling of visual features. In addition, we conducted a comprehensive ablation study on the feasibility of integrating Vision Transformers (ViT) with Mamba. Our results demonstrate that equipping the Mamba architecture with several self-attention blocks at the final layers greatly improves its capacity to capture long-range spatial dependencies. Based on our findings, we introduce a family of MambaVision models with a hierarchical architecture to meet various design criteria (a conceptual sketch of this layout follows the performance table below).

## Model Performance

MambaVision-L3-512-21K is pretrained on the ImageNet-21K dataset and finetuned on ImageNet-1K at 512 x 512 resolution.
| Name | Acc@1(%) | Acc@5(%) | #Params(M) | FLOPs(G) | Resolution |
|---|---|---|---|---|---|
| MambaVision-L3-512-21K | 88.1 | 98.6 | 739.6 | 489.1 | 512x512 |
In addition, the MambaVision models achieve a new SOTA Pareto front in terms of Top-1 accuracy and throughput.
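To make the hybrid layout described in the overview concrete, below is a minimal, purely illustrative sketch of how such a hierarchical backbone could arrange its blocks, with Mamba-based mixer blocks first and self-attention blocks in the final positions of the late stages. The stage depths, attention counts, and convolutional early stages are placeholder assumptions for illustration, not the published configuration; refer to the paper and project page for the actual architecture.

```python
def stage_layout(depth: int, n_final_attn: int) -> list[str]:
    """One hybrid stage: Mamba-style mixer blocks first, then
    self-attention blocks in the final positions of the stage."""
    return (["mamba_mixer"] * (depth - n_final_attn)
            + ["self_attention"] * n_final_attn)

# Illustrative 4-stage hierarchy (all depths/counts here are made up):
backbone = {
    "stage1": ["conv_block"] * 3,   # high-resolution, CNN-based
    "stage2": ["conv_block"] * 3,
    "stage3": stage_layout(depth=10, n_final_attn=5),  # hybrid stage
    "stage4": stage_layout(depth=6, n_final_attn=3),   # hybrid stage
}
for name, blocks in backbone.items():
    print(f"{name}: {blocks}")
```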

## Model Usage

It is highly recommended to install the requirements for MambaVision by running the following:

```bash
pip install mambavision
```

For each model, we offer two variants, for image classification and feature extraction, each of which can be imported with one line of code (a one-line loading sketch follows the example image below).

### Image Classification

In the following example, we demonstrate how MambaVision can be used for image classification. Given the following image from the [COCO dataset](https://cocodataset.org/#home) val set as an input:

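Before walking through the full example, note that both variants mentioned above load from the same checkpoint with a single line each (the class names and checkpoint id below are taken directly from the snippets in this card):

```python
from transformers import AutoModel, AutoModelForImageClassification

# Feature-extraction variant (returns stage features and pooled features):
backbone = AutoModel.from_pretrained("nvidia/MambaVision-L3-512-21K", trust_remote_code=True)

# Image-classification variant (adds the ImageNet-1K classification head):
classifier = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L3-512-21K", trust_remote_code=True)
```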
The following snippet can be used for image classification:

```python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L3-512-21K", trust_remote_code=True)

# eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 512, 512)  # MambaVision supports any input resolution
transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()

# model inference
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```

The predicted label is `brown bear, bruin, Ursus arctos`.

### Feature Extraction

MambaVision can also be used as a generic feature extractor. Specifically, we can extract the outputs of each stage of the model (4 stages) as well as the final average-pooled features that are flattened. The following snippet can be used for feature extraction:

```python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModel.from_pretrained("nvidia/MambaVision-L3-512-21K", trust_remote_code=True)

# eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 512, 512)  # MambaVision supports any input resolution
transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()

# model inference
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size())  # torch.Size([1, 1568])
print("Number of stages in extracted features:", len(features))  # 4 stages
print("Size of extracted features in stage 1:", features[0].size())  # torch.Size([1, 196, 128, 128])
print("Size of extracted features in stage 4:", features[3].size())  # torch.Size([1, 1568, 16, 16])
```

### License: [NVIDIA Source Code License-NC](https://huggingface.co/nvidia/MambaVision-L3-512-21K/blob/main/LICENSE)

## Results + Pretrained Models

### ImageNet-21K
| Name | Acc@1(%) | Acc@5(%) | #Params(M) | FLOPs(G) | Resolution | HF | Download |
|---|---|---|---|---|---|---|---|
| MambaVision-B-21K | 84.9 | 97.5 | 97.7 | 15.0 | 224x224 | link | model |
| MambaVision-L-21K | 86.1 | 97.9 | 227.9 | 34.9 | 224x224 | link | model |
| MambaVision-L2-512-21K | 87.3 | 98.4 | 241.5 | 196.3 | 512x512 | link | model |
| MambaVision-L3-256-21K | 87.3 | 98.3 | 739.6 | 122.3 | 256x256 | link | model |
| MambaVision-L3-512-21K | 88.1 | 98.6 | 739.6 | 489.1 | 512x512 | link | model |
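To compare the checkpoints in the table above, the following sketch loads each one by name. Note that only `nvidia/MambaVision-L3-512-21K` appears verbatim in this card; the other repo ids are assumed to follow the same `nvidia/<Name>` pattern.

```python
from transformers import AutoModel

names = ["MambaVision-B-21K", "MambaVision-L-21K", "MambaVision-L2-512-21K",
         "MambaVision-L3-256-21K", "MambaVision-L3-512-21K"]

for name in names:
    # Repo ids other than MambaVision-L3-512-21K are assumed, not confirmed here.
    model = AutoModel.from_pretrained(f"nvidia/{name}", trust_remote_code=True)
    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"{name}: {n_params:.1f}M params")
```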
### ImageNet-1K
| Name | Acc@1(%) | Acc@5(%) | Throughput(Img/Sec) | Resolution | #Params(M) | FLOPs(G) | HF | Download |
|---|---|---|---|---|---|---|---|---|
| MambaVision-T | 82.3 | 96.2 | 6298 | 224x224 | 31.8 | 4.4 | link | model |
| MambaVision-T2 | 82.7 | 96.3 | 5990 | 224x224 | 35.1 | 5.1 | link | model |
| MambaVision-S | 83.3 | 96.5 | 4700 | 224x224 | 50.1 | 7.5 | link | model |
| MambaVision-B | 84.2 | 96.9 | 3670 | 224x224 | 97.7 | 15.0 | link | model |
| MambaVision-L | 85.0 | 97.1 | 2190 | 224x224 | 227.9 | 34.9 | link | model |
| MambaVision-L2 | 85.3 | 97.2 | 1021 | 224x224 | 241.5 | 37.5 | link | model |
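Throughput depends heavily on hardware, batch size, and precision, and this card does not specify the benchmarking setup used for the table above. The following is therefore only a generic measurement sketch (batch size and iteration counts are arbitrary, and results will not match the table directly):

```python
import time
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("nvidia/MambaVision-L3-512-21K", trust_remote_code=True)
model.cuda().eval()

# Arbitrary batch size; reduce it if your GPU runs out of memory.
x = torch.randn(8, 3, 512, 512, device="cuda")

with torch.inference_mode():
    for _ in range(10):   # warmup iterations
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(50):   # timed iterations
        model(x)
    torch.cuda.synchronize()
    elapsed = time.time() - start

print(f"Throughput: {8 * 50 / elapsed:.1f} img/sec")
```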
## Installation

We provide a [docker file](./Dockerfile). In addition, assuming that a recent [PyTorch](https://pytorch.org/get-started/locally/) package is installed, the dependencies can be installed by running:

```bash
pip install -r requirements.txt
```
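For the Docker route, a typical build-and-run sequence looks like the following; the image tag is arbitrary, and `--gpus all` assumes the NVIDIA Container Toolkit is installed on the host:

```bash
# Build the image from the provided Dockerfile (tag name is arbitrary):
docker build -t mambavision .

# Start an interactive container with GPU access:
docker run --gpus all -it --rm mambavision
```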