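"""LangChain-compatible embedding wrapper around the BridgeTower multimodal model."""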
from typing import List

import torch
from PIL import Image
from tqdm import tqdm
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel

from lrn_vector_embeddings import bt_embeddings_from_local
from utility import encode_image, bt_embedding_from_prediction_guard


class BridgeTowerEmbeddings(BaseModel, Embeddings):
    """BridgeTower embedding model."""
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using BridgeTower.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        # BridgeTower expects an image-text pair, so pair each text with a
        # blank placeholder image when embedding text on its own.
        img = Image.new('RGB', (100, 100))
        for text in texts:
            embedding = bt_embeddings_from_local(text, img)
            embeddings.append(embedding)
        return embeddings
    
    def embed_query(self, text: str) -> List[float]:
        """Embed a query using BridgeTower.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text as a flat list of floats.
        """
        # Embed the single query text.
        embeddings = self.embed_documents([text])[0]

        # If the local helper returns a dict, keep only the text embeddings.
        if isinstance(embeddings, dict):
            embeddings = embeddings["text_embeddings"]

        # If the result is a singleton batch (nested list or tensor), unwrap it.
        if isinstance(embeddings, (list, torch.Tensor)) and len(embeddings) == 1:
            embeddings = embeddings[0]

        # Convert a tensor result to a plain Python list.
        if torch.is_tensor(embeddings):
            embeddings = embeddings.detach().tolist()

        return embeddings

    
    def embed_image_text_pairs(
        self, texts: List[str], images: List[str], batch_size: int = 2
    ) -> List[List[float]]:
        """Embed a list of image-text pairs using BridgeTower.

        Args:
            texts: The list of texts to embed.
            images: The list of paths to the images to embed.
            batch_size: The batch size to process; defaults to 2.

        Returns:
            List of embeddings, one for each image-text pair.
        """
        # Each caption must correspond to exactly one image.
        assert len(texts) == len(images), \
            "the number of captions must equal the number of images"

        processor = BridgeTowerProcessor.from_pretrained(
            "BridgeTower/bridgetower-large-itm-mlm-itc"
        )
        model = BridgeTowerForContrastiveLearning.from_pretrained(
            "BridgeTower/bridgetower-large-itm-mlm-itc"
        )

        embeddings = []
        # Note: pairs are processed one at a time; batch_size is not currently used.
        for path_to_img, text in tqdm(zip(images, texts), total=len(texts)):
            # Preprocess one image-text pair and run it through the model.
            inputs = processor(text=[text], images=[Image.open(path_to_img)], return_tensors="pt")
            outputs = model(**inputs)
            # Keep the text-side contrastive embedding as a flat list of floats.
            embedding = outputs.text_embeds.detach().numpy().tolist()[0]
            embeddings.append(embedding)

        return embeddings
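

# A minimal usage sketch, not part of the original class. It assumes the local
# helper `bt_embeddings_from_local` (from lrn_vector_embeddings) is importable
# and that the caption/image path below exists; "example.jpg" is hypothetical.
if __name__ == "__main__":
    embedder = BridgeTowerEmbeddings()

    # Text-only embeddings (each text is paired with a blank placeholder image).
    text_vectors = embedder.embed_documents(["a cat on a sofa", "a red bicycle"])
    print(len(text_vectors), "text embeddings")

    # Query embedding returned as a flat list of floats.
    query_vector = embedder.embed_query("a cat on a sofa")
    print(len(query_vector), "dimensions")

    # Image-text pair embeddings; downloads the BridgeTower checkpoint on first use.
    pair_vectors = embedder.embed_image_text_pairs(
        texts=["a cat on a sofa"], images=["example.jpg"]
    )
    print(len(pair_vectors), "pair embeddings")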