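"""Compute BridgeTower contrastive embeddings for a few image-caption pairs and save them to disk."""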
from numpy.linalg import norm
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
import torch
from PIL import Image


url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

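# Each record pairs a source URL and caption with the local path of the
# downloaded image and the path prefix used when saving its embedding tensor.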
img1 = {
    'flickr_url': url1,
    'caption': cap1,
    'image_path': './shared_data/motorcycle_1.jpg',
    'tensor_path': './shared_data/motorcycle_1'
}

img2 = {
    'flickr_url': url2,
    'caption': cap2,
    'image_path': './shared_data/motorcycle_2.jpg',
    'tensor_path': './shared_data/motorcycle_2'
}

img3 = {
    'flickr_url': url3,
    'caption': cap3,
    'image_path': './shared_data/cat_1.jpg',
    'tensor_path': './shared_data/cat_1'
}

def bt_embeddings_from_local(text, image):
    """Return BridgeTower text, image, and cross-modal embeddings for one caption/image pair."""
    # Note: the checkpoint is reloaded on every call; for repeated use, load the
    # model and processor once and reuse them.
    model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

    # Preprocess the image and tokenize the caption into model-ready tensors
    processed_inputs = processor(image, text, padding=True, return_tensors="pt")

    # Inference only, so gradient tracking is not needed
    with torch.no_grad():
        outputs = model(**processed_inputs)

    return {
        'cross_modal_embeddings': outputs.cross_embeds,
        'text_embeddings': outputs.text_embeds,
        'image_embeddings': outputs.image_embeds
    }

def save_embeddings():
    """Embed each caption/image pair and save its cross-modal embedding as a .pt file."""
    # Assumes the images have already been downloaded from their source URLs
    # to the local image_path under ./shared_data/.
    for img in [img1, img2, img3]:
        embedding = bt_embeddings_from_local(img['caption'], Image.open(img['image_path']))
        print(embedding['cross_modal_embeddings'][0].shape)  # shape of the cross-modal embedding tensor
        torch.save(embedding['cross_modal_embeddings'][0], img['tensor_path'] + '.pt')
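

# The numpy.linalg.norm import at the top is not used in the code above; a
# typical next step is comparing saved embeddings. A minimal sketch (an
# assumption, not part of the original) using cosine similarity:
def cosine_similarity(vec1, vec2):
    # Flatten to 1-D numpy arrays and compute the cosine of the angle between them
    v1 = vec1.detach().numpy().flatten()
    v2 = vec2.detach().numpy().flatten()
    return float(v1 @ v2 / (norm(v1) * norm(v2)))


if __name__ == '__main__':
    # Assumed entry point: compute and store the embeddings when run as a script
    save_embeddings()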