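"""Image retrieval with CLIP or DINOv2 embeddings.

Encodes every image in a gallery directory, ranks the gallery by cosine
similarity against each query image, and saves a plot of the top matches
for each query.

Example invocations (a sketch; assumes this file is saved as retrieval.py
and that the default dataset/ directories exist):

    python retrieval.py --model clip
    python retrieval.py --model dinov2 --outDir outputs/retrieval_dinov2
"""
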
import argparse
import os

import torch
from matplotlib import pyplot as plt
from PIL import Image
from transformers import AutoModel, AutoProcessor, CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"


def create_gallery(args, model, processor, gallery_paths):
    """Encode every gallery image into a (normalized embedding, path) pair."""
    gallery = []

    for path in gallery_paths:
        img = Image.open(os.path.join(args.gallery_path, path)).convert("RGB")
        img_inputs = processor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            if args.model == "clip":
                img_embedding = model.get_image_features(**img_inputs)
            elif args.model == "dinov2":
                # Use the mean of the token embeddings as the image feature.
                outputs = model(**img_inputs)
                img_embedding = outputs.last_hidden_state.mean(dim=1)
            # L2-normalize so dot products later are cosine similarities.
            img_embedding /= img_embedding.norm(dim=-1, keepdim=True)

        gallery.append([img_embedding, os.path.join(args.gallery_path, path)])

    return gallery
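
# Note: for large galleries, the per-image scoring loop in retrieval() could
# be replaced by stacking the gallery embeddings with torch.cat and scoring
# them against the query in a single matrix multiply; the explicit loop below
# is easier to follow.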

def retrieval(args):
    if args.model == "clip":
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    elif args.model == "dinov2":
        model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
        processor = AutoProcessor.from_pretrained("facebook/dinov2-base")
    else:
        raise ValueError(f"Unknown model {args.model!r}; expected 'clip' or 'dinov2'.")
    model.eval()


    gallery_paths = os.listdir(args.gallery_path)
    query_paths = os.listdir(args.query_path)

    print("--- Initializing gallery ---")
    gallery = create_gallery(args, model, processor, gallery_paths)

    for k, query_path in enumerate(query_paths):
        query_image = Image.open(os.path.join(args.query_path, query_path)).convert("RGB")
        img_inputs = processor(images=query_image, return_tensors="pt").to(device)
        with torch.no_grad():
            if args.model == "clip":
                query_embedding = model.get_image_features(**img_inputs)
            elif args.model == "dinov2":
                outputs = model(**img_inputs)
                query_embedding = outputs.last_hidden_state.mean(dim=1)
            # Normalize once here; the gallery embeddings are already normalized.
            query_embedding /= query_embedding.norm(dim=-1, keepdim=True)


        fig = plt.figure()
        plot_length = 11  # 1 query panel + 10 retrieved-image panels
        rank_list = []

        # Show the query image in the leftmost panel.
        query_ax = fig.add_subplot(1, plot_length, 1)
        query_ax.imshow(query_image)
        query_ax.axis('off')

        print(f"--- Starting image retrieval for query image: {query_path}")
        logit_scale = 100  # scale cosine similarity into a readable range

        for embedding, gallery_img_path in gallery:
            # Embeddings are unit-normalized, so the dot product is the
            # cosine similarity.
            similarity_score = (logit_scale * query_embedding @ embedding.t()).item()
            rank_list.append([round(similarity_score, 3), gallery_img_path])

        rank_list.sort(key=lambda x: x[0], reverse=True)  # best match first

    
        # Fill the remaining panels with the top-ranked gallery images,
        # titled with their similarity scores.
        for i in range(2, plot_length + 1):
            gallery_ax = fig.add_subplot(1, plot_length, i)
            img = Image.open(rank_list[i - 2][1])
            gallery_ax.imshow(img)
            gallery_ax.set_title('%.1f' % rank_list[i - 2][0], fontsize=8)
            gallery_ax.axis('off')
        plt.savefig(os.path.join(args.outDir, "plot_" + str(k) + ".jpg"))
        plt.close(fig)



if __name__ == "__main__":
    # Create an argument parser
    parser = argparse.ArgumentParser(description="Image retrieval with CLIP or DINOv2")

    # Add arguments
    parser.add_argument(
        '--gallery-path', 
        type=str, 
        default="dataset/gallery/",
        help="Directory containing the gallery images"
    )
    parser.add_argument(
        '--query-path', 
        type=str, 
        default="dataset/query/",
        help="Directory containing the query images"
    )
    parser.add_argument(
        '--outDir', 
        type=str, 
        default="outputs/retrieval_clip",
        help="Directory containing the output plots"
    )
    parser.add_argument(
        '--model',
        type=str,
        default="clip",
        choices=["clip", "dinov2"],
        help="Model type: clip or dinov2"
    )

    # Parse the arguments
    args = parser.parse_args()

    os.makedirs(args.outDir, exist_ok=True)
    retrieval(args)
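
# Output (per the code above): one figure per query image, saved as
# plot_<k>.jpg under --outDir, where <k> is the query's position in the
# directory listing.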