import os

import gradio as gr
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models and processors
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
dino_processor = AutoProcessor.from_pretrained("facebook/dinov2-base")
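# Note (editorial, not in the original code): AutoProcessor resolves to an image
# processor for DINOv2 (it has no tokenizer); AutoImageProcessor would be the more
# explicit choice. The two models also embed into different spaces (CLIP ViT-B/32
# projects to 512 dims, DINOv2-base pools 768-dim hidden states), which is fine
# here because each model only ranks the gallery against its own query embedding.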
def get_image_embedding(image, model, processor, model_type):
    if isinstance(image, str):  # Handle file-path input
        image = Image.open(image)
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        if model_type == "clip":
            embedding = model.get_image_features(**inputs)
        elif model_type == "dinov2":
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1)  # Global average pooling over patch tokens
        else:
            raise ValueError(f"Unknown model_type: {model_type}")
    embedding /= embedding.norm(dim=-1, keepdim=True)  # L2-normalize so dot products are cosine similarities
    return embedding
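# Usage sketch (illustrative only, not called by the app): because the embeddings
# are unit-normalized above, a plain dot product equals cosine similarity.
#   q = get_image_embedding("dataset/query/bird.jpg", clip_model, clip_processor, "clip")
#   g = get_image_embedding("dataset/gallery/some_image.jpg", clip_model, clip_processor, "clip")
#   cos_sim = (q @ g.T).item()  # scalar in [-1, 1]; higher means more similar
# "some_image.jpg" is a placeholder filename, not a known file in this repo.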
def retrieve_images(query_img, gallery_imgs, model, processor, model_type, top_k=10):
    query_embedding = get_image_embedding(query_img, model, processor, model_type)
    gallery_embeddings = []
    for img in gallery_imgs:
        emb = get_image_embedding(img, model, processor, model_type)
        gallery_embeddings.append((emb, img))
    # Score each gallery image against the query, then keep the top-k matches
    rank_list = []
    for emb, img in gallery_embeddings:
        similarity_score = (query_embedding @ emb.T).item()
        rank_list.append((similarity_score, img))
    rank_list = sorted(rank_list, key=lambda x: x[0], reverse=True)[:top_k]
    return [img for _, img in rank_list]
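# Optional speed-up, a sketch rather than part of the original app: the loop in
# retrieve_images runs one forward pass per gallery image. Hugging Face processors
# accept a list of images, so the whole gallery can be embedded in a single batch.
# Assumes the gallery fits in one batch; chunk the list for large galleries.
def retrieve_images_batched(query_img, gallery_imgs, model, processor, model_type, top_k=10):
    query_embedding = get_image_embedding(query_img, model, processor, model_type)
    images = [Image.open(img) if isinstance(img, str) else img for img in gallery_imgs]
    inputs = processor(images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        if model_type == "clip":
            embeddings = model.get_image_features(**inputs)
        else:  # "dinov2"
            embeddings = model(**inputs).last_hidden_state.mean(dim=1)
    embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
    scores = (query_embedding @ embeddings.T).squeeze(0)  # one cosine score per gallery image
    top_indices = scores.topk(min(top_k, len(gallery_imgs))).indices.tolist()
    return [gallery_imgs[i] for i in top_indices]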
def display_results(query_img, gallery_imgs, top_k):
    clip_results = retrieve_images(query_img, gallery_imgs, clip_model, clip_processor, "clip", top_k)
    dino_results = retrieve_images(query_img, gallery_imgs, dino_model, dino_processor, "dinov2", top_k)
    # Prepend the query image so each output gallery shows query + ranked matches
    return [query_img] + clip_results, [query_img] + dino_results
def gradio_interface(query_img, gallery_imgs, top_k):
    if not isinstance(gallery_imgs, list):
        gallery_imgs = [gallery_imgs]
    # gr.File yields file objects whose .name holds the temp-file path
    gallery_imgs = [img.name if hasattr(img, "name") else img for img in gallery_imgs]
    clip_res, dino_res = display_results(query_img, gallery_imgs, top_k)
    return clip_res, dino_res
gallery_path = "dataset/gallery"
filenames = os.listdir(gallery_path)
flag_filenames = [filename for filename in filenames if "flag" in filename]
tattoo_filenames = [filename for filename in filenames if "tattoo" in filename]
gallery_examples_flags = [os.path.join(gallery_path, filename) for filename in flag_filenames]
gallery_examples_tattoos = [os.path.join(gallery_path, filename) for filename in tattoo_filenames]
query_examples = ["dataset/query/american_flag46.jpg", "dataset/query/bird.jpg"]
print(gallery_examples_flags)
print(gallery_examples_tattoos)
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Query Image"),
        gr.File(file_types=["image"], label="Gallery Images", file_count="multiple", elem_id="gallery-files"),
        gr.Slider(1, 30, value=10, step=1, label="Top-K Matches"),
    ],
    outputs=[
        gr.Gallery(label="CLIP Retrieval Results", elem_id="clip-results", rows=[1], columns=[30]),
        gr.Gallery(label="DINOv2 Retrieval Results", elem_id="dino-results", rows=[1], columns=[30]),
    ],
    title="CLIP vs DINOv2 Image Retrieval",
    description="Upload a query image and gallery images to see the top-k retrieval results side by side using CLIP and DINOv2.",
    examples=[[query_examples[0], gallery_examples_flags, 10], [query_examples[1], gallery_examples_tattoos, 10]],
    css="""
    #gallery-files {
        max-height: 150px;
        overflow-y: scroll;
    }
    #clip-results, #dino-results {
        max-height: 150px;
    }
    """,
)
demo.launch(share=True)