File size: 5,304 Bytes
60e35a0
e7fca38
 
 
 
 
 
 
 
 
60e35a0
 
d1fd97e
 
 
e7fca38
 
 
 
 
 
 
 
 
7c6d6c4
 
d1fd97e
7c6d6c4
 
 
e7fca38
7c6d6c4
 
e7fca38
 
 
 
 
 
 
 
d1fd97e
60e35a0
d1fd97e
 
7c6d6c4
 
d1fd97e
 
 
 
 
 
60e35a0
 
e7fca38
 
 
 
 
 
d1fd97e
 
60e35a0
d1fd97e
 
e7fca38
7c6d6c4
 
 
 
 
 
 
 
60e35a0
e7fca38
7c6d6c4
 
60e35a0
e7fca38
 
 
 
7c6d6c4
 
 
 
 
e7fca38
 
7c6d6c4
e7fca38
 
 
 
 
 
 
7c6d6c4
e7fca38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60e35a0
d1fd97e
7c6d6c4
d1fd97e
 
 
 
 
 
 
 
 
7c6d6c4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from os import path
from IPython.display import display
from umap import UMAP
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from s3_data_to_vector_embedding import bt_embeddings_from_local
import random
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from datasets import load_dataset

# prompt templates
templates = [
    'a picture of {}',
    'an image of {}',
    'a nice {}',
    'a beautiful {}',
]
# function helps to prepare list image-text pairs from the first [test_size] data
def data_prep(hf_dataset_name, templates=templates, test_size=1000):
    # load Huggingface dataset by streaming the dataset which doesn’t download anything, and lets you use it instantly
    #dataset = load_dataset(hf_dataset_name, trust_remote_code=True, split='train', streaming=True)

    dataset = load_dataset(hf_dataset_name)
    # split dataset with specific test_size
    train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
    test_dataset = train_test_dataset['test']
    print(test_dataset)
    # get the test dataset
    img_txt_pairs = []
    for i in range(len(test_dataset)):
        img_txt_pairs.append({
            'caption' : templates[random.randint(0, len(templates)-1)],
            'pil_img' : test_dataset[i]['image']
        })
    return img_txt_pairs

        

def load_all_dataset():
    
    car_img_txt_pairs = data_prep("tanganke/stanford_cars", test_size=50)
    cat_img_txt_pairs = data_prep("yashikota/cat-image-dataset", test_size=50)
    
    return cat_img_txt_pairs, car_img_txt_pairs
# compute BridgeTower embeddings for cat image-text pairs
def load_cat_and_car_embeddings():
    # prepare image_text pairs 
    cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()
    def save_embeddings(embedding, path):
        torch.save(embedding, path)

    def load_embeddings(img_txt_pair):
        pil_img = img_txt_pair['pil_img']
        caption = img_txt_pair['caption']
        return bt_embeddings_from_local(caption, pil_img)
    
    def load_all_embeddings_from_image_text_pairs(img_txt_pairs, file_name):
        embeddings = []
        for img_txt_pair in tqdm(
                            img_txt_pairs, 
                            total=len(img_txt_pairs)
                        ):
            
            embedding = load_embeddings(img_txt_pair)
            print(embedding)
            cross_modal_embeddings = embedding['cross_modal_embeddings'][0].detach().numpy() #this is not the right way to convert tensor to numpy
            #print(cross_modal_embeddings.shape) #<class 'torch.Tensor'>
            #save_embeddings(cross_modal_embeddings, file_name)
            embeddings.append(cross_modal_embeddings)
            return cross_modal_embeddings
    

    cat_embeddings = load_all_embeddings_from_image_text_pairs(cat_img_txt_pairs, './shared_data/cat_embeddings.pt')
    car_embeddings = load_all_embeddings_from_image_text_pairs(car_img_txt_pairs, './shared_data/car_embeddings.pt')
    
    return cat_embeddings, car_embeddings
                        

# function transforms high-dimension vectors to 2D vectors using UMAP
def dimensionality_reduction(embeddings, labels):
     
     
    print(embeddings)
    X_scaled = MinMaxScaler().fit_transform(embeddings.reshape(-1, 1)) # This is not the right way to scale the data
    mapper = UMAP(n_components=2, metric="cosine").fit(X_scaled)
    df_emb = pd.DataFrame(mapper.embedding_, columns=["X", "Y"])
    df_emb["label"] = labels
    print(df_emb)
    return df_emb

def show_umap_visualization():
    def reduce_dimensions():
        cat_embeddings, car_embeddings = load_cat_and_car_embeddings()
        # stacking embeddings of cat and car examples into one numpy array
        all_embeddings = np.concatenate([cat_embeddings, car_embeddings]) # This is not the right way to scale the data

        # prepare labels for the 3 examples
        labels = ['cat'] * len(cat_embeddings) + ['car'] * len(car_embeddings)

        # compute dimensionality reduction for the 3 examples
        reduced_dim_emb = dimensionality_reduction(all_embeddings, labels)
        return reduced_dim_emb

    reduced_dim_emb = reduce_dimensions()
    # Plot the centroids against the cluster
    fig, ax = plt.subplots(figsize=(8,6)) # Set figsize

    sns.set_style("whitegrid", {'axes.grid' : False})
    sns.scatterplot(data=reduced_dim_emb, 
                    x=reduced_dim_emb['X'], 
                    y=reduced_dim_emb['Y'], 
                    hue='label', 
                    palette='bright')
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
    plt.title('Scatter plot of images of cats and cars using UMAP')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()

def  an_example_of_cat_and_car_pair_data():
    cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()
    # display an example of a cat image-text pair data
    display(cat_img_txt_pairs[0]['caption'])
    display(cat_img_txt_pairs[0]['pil_img'])

    # display an example of a car image-text pair data
    display(car_img_txt_pairs[0]['caption'])
    display(car_img_txt_pairs[0]['pil_img'])


if __name__ == '__main__':
    show_umap_visualization()