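"""BridgeTower multimodal examples: contrastive (image, text) embeddings,
image-text retrieval scoring, and masked-language-model inference."""
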
import os
import numpy as np
from numpy.linalg import norm
import requests
import torch
from PIL import Image
from transformers import (
    BridgeTowerProcessor,
    BridgeTowerForContrastiveLearning,
    BridgeTowerForImageAndTextRetrieval,
    BridgeTowerForMaskedLM,
)

url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

img1 = {
    'flickr_url': url1,
    'caption': cap1,
    'image_path': './shared_data/motorcycle_1.jpg'
}

img2 = {
    'flickr_url': url2,
    'caption': cap2,
    'image_path': './shared_data/motorcycle_2.jpg'
}

img3 = {
    'flickr_url': url3,
    'caption': cap3,
    'image_path': './shared_data/cat_1.jpg'
}
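
# Added helper (not in the original file): fetch each image from its Flickr
# URL into image_path if it is not already on disk, so the demo in __main__
# can run without pre-fetched files. Assumes the URLs above are still live.
def download_image_if_missing(img):
    if not os.path.exists(img['image_path']):
        os.makedirs(os.path.dirname(img['image_path']), exist_ok=True)
        response = requests.get(img['flickr_url'])
        response.raise_for_status()
        with open(img['image_path'], 'wb') as f:
            f.write(response.content)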

def bt_embeddings_from_local(text, image):
    """Return BridgeTower contrastive embeddings for an (image, text) pair.

    Note: the model and processor are reloaded on every call; hoist them to
    module level if you embed many pairs.
    """
    model = BridgeTowerForContrastiveLearning.from_pretrained(
        "BridgeTower/bridgetower-large-itm-mlm-itc")
    processor = BridgeTowerProcessor.from_pretrained(
        "BridgeTower/bridgetower-large-itm-mlm-itc")

    processed_inputs = processor(image, text, padding=True, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**processed_inputs)

    return {
        'cross_modal_embeddings': outputs.cross_embeds,
        'text_embeddings': outputs.text_embeds,
        'image_embeddings': outputs.image_embeds,
    }

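# Added helper (not in the original file): cosine similarity between two of
# the embedding tensors returned above, using the numpy.linalg.norm import.
# The embeddings are torch tensors, hence the detach()/flatten().
def cosine_similarity(emb_a, emb_b):
    a = emb_a.detach().numpy().flatten()
    b = emb_b.detach().numpy().flatten()
    return float(np.dot(a, b) / (norm(a) * norm(b)))
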
def bt_scores_with_image_and_text_retrieval():
    """Score candidate captions against a COCO image.

    Returns a dict mapping each caption to its image-text matching (ITM)
    logit; higher means a better match.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    texts = ["An image of two cats chilling on a couch",
             "A football player scoring a goal"]

    processor = BridgeTowerProcessor.from_pretrained(
        "BridgeTower/bridgetower-large-itm-mlm-gaudi")
    model = BridgeTowerForImageAndTextRetrieval.from_pretrained(
        "BridgeTower/bridgetower-large-itm-mlm-gaudi")

    scores = dict()
    for text in texts:
        # prepare inputs and run the forward pass
        encoding = processor(image, text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoding)
        # logits[0, 1] is the "match" logit for this (image, text) pair
        scores[text] = outputs.logits[0, 1].item()
    return scores


def bt_with_masked_input():
    """Fill in the <mask> token of a caption, conditioned on the image."""
    url = "http://images.cocodataset.org/val2017/000000360943.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    text = "a <mask> looking out of the window"

    processor = BridgeTowerProcessor.from_pretrained(
        "BridgeTower/bridgetower-large-itm-mlm-gaudi")
    model = BridgeTowerForMaskedLM.from_pretrained(
        "BridgeTower/bridgetower-large-itm-mlm-gaudi")

    # prepare inputs
    encoding = processor(image, text, return_tensors="pt")

    # forward pass
    with torch.no_grad():
        outputs = model(**encoding)

    # take the most likely token at each position and decode the sequence;
    # squeeze(0).tolist() always yields a list here, so no isinstance check
    token_ids = outputs.logits.argmax(dim=-1).squeeze(0).tolist()
    results = processor.tokenizer.decode(token_ids)

    print(results)
    return results

if __name__ == "__main__":
    for img in [img1, img2, img3]:
        download_image_if_missing(img)  # fetch from Flickr if not cached
        embeddings = bt_embeddings_from_local(img['caption'],
                                              Image.open(img['image_path']))
        print(embeddings['cross_modal_embeddings'][0].shape)
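
    # Added demo (not in the original): exercise the retrieval-scoring and
    # masked-LM examples as well.
    print(bt_scores_with_image_and_text_retrieval())
    bt_with_masked_input()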