from transformers import AutoTokenizer, AutoProcessor, CLIPModel
from PIL import Image
import torch


class Clip:
    """Singleton wrapper around the OpenAI CLIP ViT-B/32 model."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        # object.__new__ takes no extra arguments; forwarding *args/**kwargs
        # would raise a TypeError, so only the class is passed through.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not hasattr(self, "initialized"):  # Ensure __init__ runs only once for the singleton
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self.tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            self.model.eval()  # Inference-only usage
            self.initialized = True

    def get_model(self):
        return self.model

    def get_preprocess(self):
        return self.processor

    def get_tokenizer(self):
        return self.tokenizer

    def similarity(self, text, image_or_images):
        """Score `text` against one image (cosine similarity) or several
        images (softmax distribution over the per-image similarities)."""
        text_embedding = self.generate_embedding(text)
        image_embeddings = self.generate_embedding(image_or_images)
        # Normalize to unit length so the dot product is a cosine similarity.
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        if isinstance(image_embeddings, torch.Tensor):
            # Single image: return the cosine similarity directly. (Softmax over
            # a single score would always yield 1.0, which carries no information.)
            image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
            return text_embedding @ image_embeddings.T
        # Multiple images: stack the per-image scores, then apply softmax
        # across images to get a probability distribution over the candidates.
        similarities = torch.stack([
            text_embedding @ (emb / emb.norm(dim=-1, keepdim=True)).T
            for emb in image_embeddings
        ])
        return torch.nn.functional.softmax(similarities, dim=0)

    def generate_embedding(self, data):
        """Embed a PIL image, a list of PIL images, or a text string."""
        with torch.no_grad():  # Inference only; no gradients needed
            if isinstance(data, Image.Image):
                image_processed = self.processor(images=data, return_tensors="pt").to(self.device)
                return self.model.get_image_features(**image_processed)
            elif isinstance(data, list) and all(isinstance(img, Image.Image) for img in data):
                # Multiple images: return one embedding tensor per image
                image_embeddings = []
                for img in data:
                    image_processed = self.processor(images=img, return_tensors="pt").to(self.device)
                    image_embeddings.append(self.model.get_image_features(**image_processed))
                return image_embeddings
            elif isinstance(data, str):
                # CLIP's context window is 77 *tokens*, not characters, so let
                # the tokenizer truncate instead of slicing the string.
                text = self.tokenizer(data, return_tensors="pt", truncation=True, max_length=77).to(self.device)
                return self.model.get_text_features(**text)
            raise TypeError(f"Unsupported input type for embedding: {type(data)!r}")


clip = Clip()
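

# A minimal usage sketch of the singleton above. The image paths "cat.jpg" and
# "dog.jpg" are placeholder assumptions; substitute any local image files.
# Because Clip is a singleton, repeated instantiation returns the same object,
# so the model weights are loaded only once per process.
if __name__ == "__main__":
    images = [Image.open("cat.jpg"), Image.open("dog.jpg")]  # hypothetical files
    scores = clip.similarity("a photo of a cat", images)
    # `scores` is a softmax distribution over the candidate images; the
    # highest entry marks the best match for the text prompt.
    print(scores)
    assert Clip() is clip  # Singleton: same instance every time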