import os

# Disable the TorchScript JIT.
os.environ["PYTORCH_JIT"] = "0"

from einops import rearrange
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoModel, CLIPImageProcessor

hf_repo = "nvidia/RADIO-L"

image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
model = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)
model.eval()
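
# At inference time the model returns a (summary, spatial_features) pair
# (see infer_radio below); this demo visualizes the spatial features with
# a PCA color map.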

title = """RADIO: Reduce All Domains Into One"""

description = """
# RADIO

AM-RADIO is a framework for distilling multiple large vision foundation models into a single one.
RADIO, a new vision foundation model, excels across visual domains and serves as a superior drop-in replacement for vision backbones.
By integrating CLIP variants, DINOv2, and SAM through distillation, it preserves distinctive teacher features such as text grounding and segmentation correspondence.
It outperforms its teachers on ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear-probing segmentation (+3.8%), improves vision-language models (up to +1.5% with LLaVA 1.5), scales to any resolution, and supports non-square images.

# Instructions

Paste an image or pick one from the gallery of examples, then click the "Submit" button.
"""

inputs = [
    gr.Image(type="pil"),
]

examples = [
    "samples/IMG_0996.jpeg",
    "samples/IMG_1061.jpeg",
    "samples/IMG_1338.jpeg",
    "samples/IMG_4319.jpeg",
    "samples/IMG_5104.jpeg",
    "samples/IMG_5139.jpeg",
    "samples/IMG_6225.jpeg",
    "samples/IMG_6814.jpeg",
    "samples/IMG_7459.jpeg",
    "samples/IMG_7577.jpeg",
    "samples/IMG_7687.jpeg",
    "samples/IMG_9862.jpeg",
]

outputs = [
    gr.Textbox(label="Feature Shape"),
    gr.Image(),
]


def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    """Fit a 3-component PCA on (N, C) features and compute robust per-channel
    color bounds, ignoring values beyond m median absolute deviations."""
    assert len(features.shape) == 2, "features should be (N, C)"
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    colors = features @ reduction_mat
    if remove_first_component:
        # Treat points with a low first PCA component as foreground and
        # refit the PCA on those points only.
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0]).bool()
    # Robust outlier rejection: keep values within m median absolute deviations.
    d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    try:
        rins = colors[fg_mask][s[:, 0] < m, 0]
        gins = colors[fg_mask][s[:, 1] < m, 1]
        bins = colors[fg_mask][s[:, 2] < m, 2]
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    except Exception:
        # Fall back to unfiltered bounds if the outlier filter rejects everything.
        rins = colors
        gins = colors
        bins = colors
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
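
# Usage sketch for get_robust_pca (the random features below are an
# illustrative stand-in for real patch features):
#   feats = torch.randn(1024, 768)          # (N, C) flattened patch features
#   mat, lo, hi = get_robust_pca(feats)     # mat: (C, 3); lo, hi: (3,)
#   rgb = (((feats @ mat) - lo) / (hi - lo)).clamp(0, 1)   # (N, 3) colors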


def get_pca_map(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
):
    """
    feature_map: (1, h, w, C) is the feature map of a single image.
    """
    if feature_map.shape[0] != 1:
        # Assume an unbatched (h, w, C) map and add the batch dimension.
        feature_map = feature_map[None]
    if pca_stats is None:
        reduct_mat, color_min, color_max = get_robust_pca(
            feature_map.reshape(-1, feature_map.shape[-1])
        )
    else:
        # Reuse precomputed PCA statistics, e.g. to color several maps consistently.
        reduct_mat, color_min, color_max = pca_stats
    pca_color = feature_map @ reduct_mat
    pca_color = (pca_color - color_min) / (color_max - color_min)
    pca_color = pca_color.clamp(0, 1)
    # Upsample the (1, h, w, 3) color map to the target image size.
    pca_color = F.interpolate(
        pca_color.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    pca_color = pca_color.cpu().numpy().squeeze(0)
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color
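
# Shape sketch (hypothetical sizes): a (1, 24, 32, 768) feature map with
# img_size=(384, 512) yields a (384, 512, 3) array in [0, 1], ready for
# display with gr.Image.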


def pad_image_to_multiple_of_16(image):
    # Round each dimension up to the next multiple of 16.
    width, height = image.size
    new_width = (width + 15) // 16 * 16
    new_height = (height + 15) // 16 * 16
    # Split the padding evenly between the two sides of each dimension.
    pad_width = new_width - width
    pad_height = new_height - height
    left = pad_width // 2
    right = pad_width - left
    top = pad_height // 2
    bottom = pad_height - top
    # Apply the padding.
    padded_image = ImageOps.expand(image, (left, top, right, bottom), fill='black')
    return padded_image
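
# Worked example (hypothetical size): a 500x375 image pads up to 512x384;
# the extra 12x9 pixels split as left=6/right=6 and top=4/bottom=5, keeping
# the original content centered.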


@spaces.GPU  # request a GPU on ZeroGPU-hosted Spaces
def infer_radio(image):
    """Run RADIO on an image; return the feature shape and a PCA visualization."""
    # Match the model dtype and device to the bfloat16 inputs prepared below.
    model.to(torch.bfloat16).cuda()
    image = pad_image_to_multiple_of_16(image)
    width, height = image.size

    pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
    pixel_values = pixel_values.to(torch.bfloat16).cuda()

    # Keep only the spatial features; the summary token is unused here.
    _, features = model(pixel_values)

    # One feature vector per non-overlapping patch.
    num_rows = height // model.patch_size
    num_cols = width // model.patch_size
    features = features.detach()
    features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float()

    pca_viz = get_pca_map(features, (height, width), interpolation='bilinear')
    return f"{features.shape}", pca_viz


# Create the Gradio interface.
demo = gr.Interface(
    fn=infer_radio,
    inputs=inputs,
    examples=examples,
    outputs=outputs,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()