from transformers import AutoTokenizer, AutoProcessor, CLIPModel
from PIL import Image
import torch


class Clip:
    """Singleton wrapper around the openai/clip-vit-base-patch32 model."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ takes no extra arguments, so don't forward them
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not hasattr(self, 'initialized'):  # Ensure __init__ runs only once
            # Fall back to CPU when no GPU is available
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self.tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            self.initialized = True

    def get_model(self):
        return self.model

    def get_preprocess(self):
        return self.processor

    def get_tokenizer(self):
        return self.tokenizer

    def similarity(self, text, image_or_images):
        """Compute text-image similarity for a single image or a list of images."""
        text_embedding = self.generate_embedding(text)
        image_embeddings = self.generate_embedding(image_or_images)

        if isinstance(image_embeddings, torch.Tensor):
            # Single image: raw dot product between the (unnormalized) embeddings
            similarity = text_embedding @ image_embeddings.T
        else:
            # Multiple images: image_embeddings is a list of tensors
            similarity = torch.stack([text_embedding @ image_embedding.T for image_embedding in image_embeddings])
            # Softmax turns the scores into a distribution over the candidate images
            similarity = torch.nn.functional.softmax(similarity, dim=0)
        return similarity

    def generate_embedding(self, data):
        """Embed a PIL image, a list of PIL images, or a text string."""
        with torch.no_grad():  # inference only, no gradients needed
            if isinstance(data, Image.Image):
                # Single image case
                image_processed = self.processor(images=data, return_tensors="pt").to(self.device)
                return self.model.get_image_features(**image_processed)

            elif isinstance(data, list) and all(isinstance(img, Image.Image) for img in data):
                # Multiple images case
                image_embeddings = []
                for img in data:
                    image_processed = self.processor(images=img, return_tensors="pt").to(self.device)
                    image_embeddings.append(self.model.get_image_features(**image_processed))
                return image_embeddings

            elif isinstance(data, str):
                # CLIP's text encoder accepts at most 77 tokens, so truncate at the tokenizer level
                text = self.tokenizer(data, return_tensors="pt", truncation=True, max_length=77).to(self.device)
                return self.model.get_text_features(**text)

            raise TypeError("data must be a PIL Image, a list of PIL Images, or a string")

# Module-level singleton instance shared by importers of this module
clip = Clip()
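

# Minimal usage sketch (assumptions: local image files "cat.jpg" and "dog.jpg"
# exist; exact scores depend on the downloaded model weights). It demonstrates
# the singleton behavior and both the single-image and multi-image paths of
# similarity().
if __name__ == "__main__":
    cat = Image.open("cat.jpg")
    dog = Image.open("dog.jpg")

    # Every Clip() call returns the same instance, so the model loads only once
    assert Clip() is clip

    # Single image: one raw dot-product score
    print(clip.similarity("a photo of a cat", cat))

    # Multiple images: a softmax distribution over the candidates
    print(clip.similarity("a photo of a cat", [cat, dog]))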