YagmurCA committed on
Commit
85bfc65
·
verified ·
1 Parent(s): 550d274

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import os
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoModel
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Load models and processors
12
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
13
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
14
+ dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
15
+ dino_processor = AutoProcessor.from_pretrained("facebook/dinov2-base")
16
+
17
def get_image_embedding(image, model, processor, model_type):
    """Return an L2-normalised embedding for a single image.

    Args:
        image: Either an already-loaded image object accepted by
            ``processor``, or a ``str`` path that is opened with PIL.
        model: The embedding model matching ``model_type``. For ``"clip"``
            it must provide ``get_image_features``; for ``"dinov2"`` its
            output must expose ``last_hidden_state``.
        processor: Callable invoked as
            ``processor(images=..., return_tensors="pt")``; its result is
            moved to the module-level ``device`` and unpacked into the model.
        model_type: ``"clip"`` or ``"dinov2"``.

    Returns:
        The embedding tensor divided by its norm along the last dimension.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    # Validate up front: previously an unknown model_type fell through both
    # branches and failed later with an UnboundLocalError on ``embedding``.
    if model_type not in ("clip", "dinov2"):
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'clip' or 'dinov2'."
        )

    if isinstance(image, str):  # Handle file input
        # The context manager releases the file handle deterministically;
        # convert() yields a fully loaded RGB copy that outlives the file.
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        if model_type == "clip":
            embedding = model.get_image_features(**inputs)
        else:  # "dinov2"
            outputs = model(**inputs)
            # Global pooling: average over dim 1 of the last hidden state.
            embedding = outputs.last_hidden_state.mean(dim=1)
        # Normalise so a dot product between embeddings is a cosine similarity.
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding
29
+
30
def retrieve_images(query_img, gallery_imgs, model, processor, model_type, top_k=10):
    """Rank gallery images by similarity to the query and return the best ones.

    Args:
        query_img: Query image (or path) passed to ``get_image_embedding``.
        gallery_imgs: Iterable of gallery images (or paths). ``None`` or an
            empty collection yields an empty result.
        model: Embedding model forwarded to ``get_image_embedding``.
        processor: Processor forwarded to ``get_image_embedding``.
        model_type: Model family identifier forwarded to
            ``get_image_embedding``.
        top_k: Maximum number of matches to return. Coerced to ``int``;
            values below zero are treated as zero.

    Returns:
        Up to ``top_k`` items of ``gallery_imgs``, most similar first.
        Items with equal scores keep their gallery order.
    """
    # A UI slider may deliver the value as a float, which cannot be used as a
    # slice bound; a negative bound would silently drop items from the end.
    top_k = max(int(top_k), 0)
    if top_k == 0 or not gallery_imgs:
        return []

    query_embedding = get_image_embedding(query_img, model, processor, model_type)

    scored = []
    for img in gallery_imgs:
        emb = get_image_embedding(img, model, processor, model_type)
        # Embeddings are unit-normalised, so the dot product is the cosine score.
        similarity_score = (query_embedding @ emb.T).item()
        scored.append((similarity_score, img))

    # Sort on the score only: the images themselves may not be comparable.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [img for _, img in scored[:top_k]]
45
+
46
def display_results(query_img, gallery_imgs, top_k):
    """Run the same retrieval with CLIP and with DINOv2.

    Returns a pair ``(clip_list, dino_list)``. Each list starts with the
    query image, followed by that model's ranked matches from
    ``retrieve_images``.
    """
    backends = (
        (clip_model, clip_processor, "clip"),
        (dino_model, dino_processor, "dinov2"),
    )
    # Evaluated in order, so CLIP runs first and DINOv2 second.
    with_clip, with_dino = (
        [query_img] + retrieve_images(query_img, gallery_imgs, net, proc, kind, top_k)
        for net, proc, kind in backends
    )
    return with_clip, with_dino
50
+
51
def gradio_interface(query_img, gallery_imgs, top_k):
    """Gradio callback: normalise the UI inputs and run both retrievals.

    Args:
        query_img: The query image from the UI, or ``None`` if none was given.
        gallery_imgs: ``None``, a single uploaded file, or a list of uploaded
            files. Items exposing a ``name`` attribute are replaced by it;
            other items are passed through unchanged.
        top_k: Number of matches requested per model.

    Returns:
        A pair ``(clip_results, dino_results)`` from ``display_results``, or
        two empty lists when no query image was provided.
    """
    # Without a query there is nothing to rank against; return empty
    # galleries instead of failing inside the image processor.
    if query_img is None:
        return [], []

    # No upload arrives as None; wrapping it as [None] would crash later.
    if gallery_imgs is None:
        gallery_imgs = []
    elif not isinstance(gallery_imgs, list):
        gallery_imgs = [gallery_imgs]

    # Handle NamedString issue: use the ``name`` attribute when one exists.
    gallery_imgs = [getattr(img, "name", img) for img in gallery_imgs]

    clip_res, dino_res = display_results(query_img, gallery_imgs, top_k)
    return clip_res, dino_res
57
+
58
+ import copy
59
+
60
gallery_path = "dataset/gallery"


def gallery_files_containing(keyword, names):
    """Return the entries of ``names`` whose text contains ``keyword``."""
    return [name for name in names if keyword in name]


# Example data for the demo. The listing is sorted so the example galleries
# have a stable order, and a missing directory yields empty examples instead
# of aborting at startup.
if os.path.isdir(gallery_path):
    filenames = sorted(os.listdir(gallery_path))
else:
    print(f"Warning: gallery directory {gallery_path!r} not found; example galleries will be empty.")
    filenames = []

# The keyword must be matched against each file name. The previous code tested
# `"flag" in filenames` (membership in the whole list), so the flag examples
# were empty unless a file was literally named "flag".
flag_filenames = gallery_files_containing("flag", filenames)
tattoo_filenames = gallery_files_containing("tattoo", filenames)

gallery_examples_flags = [os.path.join(gallery_path, filename) for filename in flag_filenames]
gallery_examples_tattoos = [os.path.join(gallery_path, filename) for filename in tattoo_filenames]

query_examples = ["dataset/query/american_flag46.jpg", "dataset/query/bird.jpg"]

print(gallery_examples_flags)
print(gallery_examples_tattoos)
73
+
74
# Side-by-side comparison UI: one query image, a set of gallery files and a
# top-k slider go in; one result gallery per model comes out.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Query Image"),
        gr.File(file_types=["image"], label="Gallery Images", file_count="multiple", elem_id="gallery-files"),
        gr.Slider(1, 30, value=10, step=1, label="Top-K Matches"),
    ],
    outputs=[
        # rows/columns are passed as plain ints (previously one-element lists).
        # NOTE(review): confirm against the installed Gradio version's Gallery
        # signature. 30 columns matches the slider maximum.
        gr.Gallery(label="CLIP Retrieval Results", elem_id="clip-results", rows=1, columns=30),
        gr.Gallery(label="DINOv2 Retrieval Results", elem_id="dino-results", rows=1, columns=30),
    ],
    title="CLIP vs DINOv2 Image Retrieval",
    description="Upload a query image and gallery images to see the top-k retrieval results side by side using CLIP and DINOv2.",
    examples=[[query_examples[0], gallery_examples_flags, 10], [query_examples[1], gallery_examples_tattoos, 10]],
    css="""
    #gallery-files {
        max-height: 150px;
        overflow-y: scroll;
    }
    #clip-results, #dino-results {
        max-height: 150px;
    }
    """,
)

# Only start the server when run as a script, so importing this module
# (tests, reload tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)