Upload app.py
app.py ADDED
@@ -0,0 +1,99 @@
+import gradio as gr
+import torch
+from PIL import Image
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoModel
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load models and processors
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
+dino_processor = AutoProcessor.from_pretrained("facebook/dinov2-base")
+
+def get_image_embedding(image, model, processor, model_type):
+    if isinstance(image, str):  # Accept either a file path or a PIL image
+        image = Image.open(image)
+    inputs = processor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        if model_type == "clip":
+            embedding = model.get_image_features(**inputs)
+        elif model_type == "dinov2":
+            outputs = model(**inputs)
+            embedding = outputs.last_hidden_state.mean(dim=1)  # Global pooling over token embeddings
+    embedding /= embedding.norm(dim=-1, keepdim=True)  # L2-normalize so dot products are cosine similarities
+    return embedding
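+
+# Sketch (assumption, not wired into the app): both processors accept a list of
+# images, so a whole gallery could be embedded in one forward pass, e.g. for CLIP:
+#   inputs = clip_processor(images=[Image.open(p) for p in paths], return_tensors="pt").to(device)
+#   embeddings = clip_model.get_image_features(**inputs)  # shape: (len(paths), dim)
+# where `paths` is a hypothetical list of gallery file paths.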
+
+def retrieve_images(query_img, gallery_imgs, model, processor, model_type, top_k=10):
+    query_embedding = get_image_embedding(query_img, model, processor, model_type)
+
+    gallery_embeddings = []
+    for img in gallery_imgs:
+        emb = get_image_embedding(img, model, processor, model_type)
+        gallery_embeddings.append((emb, img))
+
+    # Embeddings are unit-normalized, so the dot product is the cosine similarity
+    rank_list = []
+    for emb, img in gallery_embeddings:
+        similarity_score = (query_embedding @ emb.T).item()
+        rank_list.append((similarity_score, img))
+
+    rank_list = sorted(rank_list, key=lambda x: x[0], reverse=True)[:top_k]
+    return [img for _, img in rank_list]
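+
+# The scoring loop above could also be vectorized (sketch, same result assumed):
+#   gallery = torch.cat([emb for emb, _ in gallery_embeddings])   # (N, dim)
+#   scores = (query_embedding @ gallery.T).squeeze(0)             # (N,)
+#   top = torch.topk(scores, k=min(top_k, len(gallery_imgs))).indices.tolist()
+#   results = [gallery_imgs[i] for i in top]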
+
+def display_results(query_img, gallery_imgs, top_k):
+    clip_results = retrieve_images(query_img, gallery_imgs, clip_model, clip_processor, "clip", top_k)
+    dino_results = retrieve_images(query_img, gallery_imgs, dino_model, dino_processor, "dinov2", top_k)
+    # Prepend the query image so each gallery shows it first for comparison
+    return [query_img] + clip_results, [query_img] + dino_results
+
+def gradio_interface(query_img, gallery_imgs, top_k):
+    if not isinstance(gallery_imgs, list):
+        gallery_imgs = [gallery_imgs]
+    # gr.File can return NamedString/tempfile objects; reduce them to plain file paths
+    gallery_imgs = [img.name if hasattr(img, 'name') else img for img in gallery_imgs]
+    clip_res, dino_res = display_results(query_img, gallery_imgs, top_k)
+    return clip_res, dino_res
+
+# Build example galleries from the bundled dataset
+gallery_path = "dataset/gallery"
+filenames = os.listdir(gallery_path)
+
+flag_filenames = [filename for filename in filenames if "flag" in filename]
+tattoo_filenames = [filename for filename in filenames if "tattoo" in filename]
+
+gallery_examples_flags = [os.path.join(gallery_path, filename) for filename in flag_filenames]
+gallery_examples_tattoos = [os.path.join(gallery_path, filename) for filename in tattoo_filenames]
+
+query_examples = ["dataset/query/american_flag46.jpg", "dataset/query/bird.jpg"]
+
+print(gallery_examples_flags)
+print(gallery_examples_tattoos)
+
+demo = gr.Interface(
+    fn=gradio_interface,
+    inputs=[
+        gr.Image(type="pil", label="Query Image"),
+        gr.File(file_types=["image"], label="Gallery Images", file_count="multiple", elem_id="gallery-files"),
+        gr.Slider(1, 30, value=10, step=1, label="Top-K Matches"),
+    ],
+    outputs=[
+        gr.Gallery(label="CLIP Retrieval Results", elem_id="clip-results", rows=[1], columns=[30]),
+        gr.Gallery(label="DINOv2 Retrieval Results", elem_id="dino-results", rows=[1], columns=[30]),
+    ],
+    title="CLIP vs DINOv2 Image Retrieval",
+    description="Upload a query image and gallery images to see the top-k retrieval results side by side using CLIP and DINOv2.",
+    examples=[[query_examples[0], gallery_examples_flags, 10], [query_examples[1], gallery_examples_tattoos, 10]],
+    css="""
+    #gallery-files {
+        max-height: 150px;
+        overflow-y: scroll;
+    }
+    #clip-results, #dino-results {
+        max-height: 150px;
+    }
+    """,
+)
+
+demo.launch(share=True)