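"""Gradio demo comparing image retrieval with CLIP and DINOv2.

Upload a query image and a set of gallery images; the app embeds them with
both models and shows each model's top-k most similar gallery images.
"""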
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models and processors
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
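
# Note: CLIP embeds images into a joint image-text space, while DINOv2 is a
# purely visual self-supervised model, so the two can rank galleries differently.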

def get_image_embedding(image, model, processor, model_type):
    if isinstance(image, str):  # Gallery entries arrive as file paths
        image = Image.open(image).convert("RGB")  # Force 3 channels so both processors accept the input
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        if model_type == "clip":
            embedding = model.get_image_features(**inputs)
        elif model_type == "dinov2":
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1)  # Mean-pool patch tokens into a global descriptor
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")
    embedding /= embedding.norm(dim=-1, keepdim=True)  # L2-normalize so dot products are cosine similarities
    return embedding

def retrieve_images(query_img, gallery_imgs, model, processor, model_type, top_k=10):
    query_embedding = get_image_embedding(query_img, model, processor, model_type)

    # Score each gallery image against the query; because embeddings are
    # L2-normalized, the dot product below is the cosine similarity.
    rank_list = []
    for img in gallery_imgs:
        emb = get_image_embedding(img, model, processor, model_type)
        similarity_score = (query_embedding @ emb.T).item()
        rank_list.append((similarity_score, img))

    rank_list = sorted(rank_list, key=lambda x: x[0], reverse=True)[:top_k]
    return [img for _, img in rank_list]
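
# Example usage (hypothetical file names), kept as a comment for reference:
#   top5 = retrieve_images("query.jpg", ["a.jpg", "b.jpg", "c.jpg"],
#                          clip_model, clip_processor, "clip", top_k=5)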

def display_results(query_img, gallery_imgs, top_k):
    clip_results = retrieve_images(query_img, gallery_imgs, clip_model, clip_processor, "clip", top_k)
    dino_results = retrieve_images(query_img, gallery_imgs, dino_model, dino_processor, "dinov2", top_k)
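    # Prepend the query image so each results gallery starts with the image being matched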
    return [query_img] + clip_results, [query_img] + dino_results

def gradio_interface(query_img, gallery_imgs, top_k):
    if not isinstance(gallery_imgs, list):
        gallery_imgs = [gallery_imgs]
    gallery_imgs = [img.name if hasattr(img, 'name') else img for img in gallery_imgs]  # gr.File yields temp-file objects; use their .name path
    clip_res, dino_res = display_results(query_img, gallery_imgs, top_k)
    return clip_res, dino_res

gallery_path = "dataset/gallery"
filenames = os.listdir(gallery_path)
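# The listing above assumes dataset/gallery holds images whose file names
# contain "flag" or "tattoo"; matching files become the two example galleries.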

flag_filenames = [filename for filename in filenames if "flag" in filename]
tattoo_filenames = [filename for filename in filenames if "tattoo" in filename]

gallery_examples_flags = [os.path.join(gallery_path, filename) for filename in flag_filenames]
gallery_examples_tattoos = [os.path.join(gallery_path, filename) for filename in tattoo_filenames]

query_examples = ["dataset/query/american_flag46.jpg", "dataset/query/bird.jpg"]

print(gallery_examples_flags)
print(gallery_examples_tattoos)

demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Query Image"),
        gr.File(file_types=["image"], label="Gallery Images", file_count="multiple", elem_id="gallery-files"),
        gr.Slider(1, 30, value=10, step=1, label="Top-K Matches"),
    ],
    outputs=[
        gr.Gallery(label="CLIP Retrieval Results", elem_id="clip-results", rows=1, columns=30),
        gr.Gallery(label="DINOv2 Retrieval Results", elem_id="dino-results", rows=1, columns=30),
    ],
    title="CLIP vs DINOv2 Image Retrieval",
    description="Upload a query image and gallery images to see the top-k retrieval results side by side using CLIP and DINOv2.",
    examples=[[query_examples[0], gallery_examples_flags, 10], [query_examples[1], gallery_examples_tattoos, 10]],
    css="""
    #gallery-files {
        max-height: 150px;
        overflow-y: scroll;
    }
    #clip-results, #dino-results {
        max-height: 150px;
    }
    """
)

if __name__ == "__main__":
    demo.launch(share=True)