# CLIPvsDINOv2/scripts/image_retrieval.py
import argparse
import os

import torch
from matplotlib import pyplot as plt
from PIL import Image
from transformers import AutoModel, AutoProcessor, CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"


def create_gallery(gallery_paths, model, processor, args):
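    """Embed every gallery image and return a list of [embedding, image_path] pairs.

    Embeddings are L2-normalized, so a plain dot product against a normalized
    query embedding later equals their cosine similarity.
    """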
    gallery = []
    for path in gallery_paths:
        img = Image.open(os.path.join(args.gallery_path, path)).convert("RGB")
        img_inputs = processor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            if args.model == "clip":
                # CLIP has a dedicated projection head for image embeddings
                img_embedding = model.get_image_features(**img_inputs)
            elif args.model == "dinov2":
                # DINOv2 has no such head; mean-pool the token embeddings instead
                outputs = model(**img_inputs)
                img_embedding = outputs.last_hidden_state.mean(dim=1)
        img_embedding /= img_embedding.norm(dim=-1, keepdim=True)
        gallery.append([img_embedding, os.path.join(args.gallery_path, path)])
    return gallery


def retrieval(args):
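    """Load the chosen backbone, embed the gallery, and for each query image
    save a plot of the query next to its highest-scoring gallery matches."""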
    if args.model == "clip":
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    elif args.model == "dinov2":
        model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
        processor = AutoProcessor.from_pretrained("facebook/dinov2-base")
    model.eval()
    gallery_paths = sorted(os.listdir(args.gallery_path))
    query_paths = sorted(os.listdir(args.query_path))
    print("--- Initializing gallery ---")
    gallery = create_gallery(gallery_paths, model, processor, args)
    for k, query_path in enumerate(query_paths):
        query_image = Image.open(os.path.join(args.query_path, query_path)).convert("RGB")
        img_inputs = processor(images=query_image, return_tensors="pt").to(device)
        with torch.no_grad():
            if args.model == "clip":
                query_embedding = model.get_image_features(**img_inputs)
            elif args.model == "dinov2":
                outputs = model(**img_inputs)
                query_embedding = outputs.last_hidden_state.mean(dim=1)
        query_embedding /= query_embedding.norm(dim=-1, keepdim=True)
        fig = plt.figure()
        plot_length = 11  # 1 query panel + 10 gallery panels per figure
        rank_list = []
        query_ax = fig.add_subplot(1, plot_length, 1)  # query image goes in the leftmost panel
        query_ax.imshow(query_image)
        query_ax.axis('off')
        print(f"--- Starting image retrieval for query image: {query_path} ---")
        logit_scale = 100
        for embedding, img_path in gallery:
            # both embeddings are already unit-normalized, so the dot product
            # is the cosine similarity between query and gallery image
            similarity_score = (logit_scale * query_embedding @ embedding.t()).item()
            rank_list.append([round(similarity_score, 3), img_path])
        rank_list.sort(key=lambda x: x[0], reverse=True)  # best match first
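        # Panels 2..plot_length show the top matches in rank order, each
        # titled with its similarity score.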
        for i, (score, img_path) in enumerate(rank_list[:plot_length - 1]):
            gallery_ax = fig.add_subplot(1, plot_length, i + 2)
            gallery_ax.imshow(Image.open(img_path))
            gallery_ax.set_title('%.1f' % score, fontsize=8)
            gallery_ax.axis('off')
        fig.savefig(os.path.join(args.outDir, "plot_" + str(k) + ".jpg"))
        plt.close(fig)


if __name__ == "__main__":
    # Create an argument parser
    parser = argparse.ArgumentParser(description="CLIP / DINOv2 image retriever")
    # Add arguments
    parser.add_argument(
        '--gallery-path',
        type=str,
        default="dataset/gallery/",
        help="Directory containing the gallery images"
    )
    parser.add_argument(
        '--query-path',
        type=str,
        default="dataset/query/",
        help="Directory containing the query images"
    )
    parser.add_argument(
        '--outDir',
        type=str,
        default="outputs/retrieval_clip",
        help="Directory where the output plots are saved"
    )
    parser.add_argument(
        '--model',
        type=str,
        default="clip",
        choices=["clip", "dinov2"],
        help="Model type: 'clip' or 'dinov2'"
    )
    # Parse the arguments
    args = parser.parse_args()
    os.makedirs(args.outDir, exist_ok=True)
    retrieval(args)
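
# Example invocations (paths are the argparse defaults; adjust to your layout):
#   python scripts/image_retrieval.py --model clip --outDir outputs/retrieval_clip
#   python scripts/image_retrieval.py --model dinov2 --outDir outputs/retrieval_dinov2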