import argparse
import os

import torch
from matplotlib import pyplot as plt
from PIL import Image
from transformers import AutoModel, AutoProcessor, CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"


def create_gallery(gallery_paths, model, processor, args):
    """Embed every gallery image and return a list of [embedding, path] pairs."""
    gallery = []
    for path in gallery_paths:
        img = Image.open(os.path.join(args.gallery_path, path))
        img_inputs = processor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            if args.model == "clip":
                img_embedding = model.get_image_features(**img_inputs)
            elif args.model == "dinov2":
                # DINOv2 has no pooled image head; mean-pool the patch tokens instead
                outputs = model(**img_inputs)
                img_embedding = outputs.last_hidden_state.mean(dim=1)
        # L2-normalize so that dot products below are cosine similarities
        img_embedding /= img_embedding.norm(dim=-1, keepdim=True)
        gallery.append([img_embedding, os.path.join(args.gallery_path, path)])
    return gallery


def retrieval(args):
    if args.model == "clip":
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    elif args.model == "dinov2":
        # Load DINOv2 model
        model_name = "facebook/dinov2-base"
        model = AutoModel.from_pretrained(model_name).to(device)
        processor = AutoProcessor.from_pretrained(model_name)
    else:
        raise ValueError(f"Unsupported model: {args.model}")
    model.eval()

    # Sort for a deterministic gallery order and output numbering
    gallery_paths = sorted(os.listdir(args.gallery_path))
    query_paths = sorted(os.listdir(args.query_path))

    print("--- Initializing gallery ---")
    gallery = create_gallery(gallery_paths, model, processor, args)

    for k, query_path in enumerate(query_paths):
        query_image = Image.open(os.path.join(args.query_path, query_path))
        img_inputs = processor(images=query_image, return_tensors="pt").to(device)
        with torch.no_grad():
            if args.model == "clip":
                query_embedding = model.get_image_features(**img_inputs)
            elif args.model == "dinov2":
                outputs = model(**img_inputs)
                query_embedding = outputs.last_hidden_state.mean(dim=1)
        query_embedding /= query_embedding.norm(dim=-1, keepdim=True)

        fig = plt.figure(figsize=(22, 3))  # wide figure so all panels stay legible
        plot_length = 11  # 1 query panel + 10 retrieval panels
        rank_list = []

        # Query image goes in the leftmost panel of the plot
        query_ax = fig.add_subplot(1, plot_length, 1)
        query_ax.imshow(query_image)
        query_ax.axis('off')

        print(f"--- Starting image retrieval for query image: {query_path} ---")
        logit_scale = 100  # same scaling factor CLIP applies to its logits
        for item in gallery:
            # Embeddings are already normalized, so this is a scaled cosine similarity
            similarity_score = (logit_scale * query_embedding @ item[0].t()).item()
            # Add gallery image with its similarity score to the ranking list
            rank_list.append([round(similarity_score, 3), item[1]])

        # Sort gallery images by similarity, best match first
        rank_list = sorted(rank_list, key=lambda x: x[0], reverse=True)

        # Plot the top-ranked gallery images to the right of the query
        num_results = min(plot_length - 1, len(rank_list))
        for i in range(num_results):
            gallery_ax = fig.add_subplot(1, plot_length, i + 2)
            img = Image.open(rank_list[i][1])
            gallery_ax.imshow(img)
            gallery_ax.set_title('%.1f' % rank_list[i][0], fontsize=8)  # similarity score as title
            gallery_ax.axis('off')

        plt.savefig(os.path.join(args.outDir, "plot_" + str(k) + ".jpg"))
        plt.close()


if __name__ == "__main__":
    # Create an argument parser
    parser = argparse.ArgumentParser(description="CLIP/DINOv2 Image Retriever")

    # Add arguments
    parser.add_argument(
        '--gallery-path', type=str, default="dataset/gallery/",
        help="Directory containing the gallery images"
    )
    parser.add_argument(
        '--query-path', type=str, default="dataset/query/",
        help="Directory containing the query images"
    )
    parser.add_argument(
        '--outDir', type=str, default="outputs/retrieval_clip",
        help="Directory in which to save the output plots"
    )
    parser.add_argument(
        '--model', type=str, default="clip",
        help="Model type, i.e. 'clip' or 'dinov2'"
    )

    # Parse the arguments
    args = parser.parse_args()
    os.makedirs(args.outDir, exist_ok=True)

    retrieval(args)
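
# Example invocations, as a rough sketch: the script filename "retrieval.py" is
# assumed here (not given in the source), and the dataset/ directories must
# already contain images matching the default argument paths above.
#
#   python retrieval.py                                   # CLIP with default paths
#   python retrieval.py --model dinov2 \
#       --gallery-path dataset/gallery/ \
#       --query-path dataset/query/ \
#       --outDir outputs/retrieval_dinov2                 # DINOv2, separate output dir
#
# Each query image produces one plot_<k>.jpg in --outDir: the query on the left,
# followed by its top-ranked gallery matches titled with their similarity scores.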