# File size: 3,667 Bytes
# Revision: 405f2d4

from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    translate_labels,
    bert_tokenizer
)

import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt

from session import _get_state


from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForMaskedLM,
)

def softmax(logits, axis=0):
    """Return the softmax of *logits* along *axis* (default 0, as before).

    The per-slice maximum is subtracted before exponentiating. This stops
    large logits from overflowing ``np.exp`` to ``inf``, which would otherwise
    turn the probabilities into ``nan``. The result is mathematically
    unchanged.

    Args:
        logits: Array-like of scores.
        axis: Axis along which the probabilities sum to one.

    Returns:
        A float ``np.ndarray`` with the same shape as *logits*.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # np.max cannot reduce an empty array; the softmax of nothing is nothing.
        return np.exp(logits)
    exps = np.exp(logits - np.max(logits, axis=axis, keepdims=True))
    return exps / np.sum(exps, axis=axis, keepdims=True)

def _mask_random_token(caption):
    """Return *caption* with one randomly chosen word-piece replaced by [MASK].

    NOTE(review): assumes ``bert_tokenizer.encode`` wraps the ids in exactly
    one leading and one trailing special token ([CLS] ... [SEP]). That is the
    Hugging Face default for BERT-style tokenizers; confirm for this one.
    """
    ids = bert_tokenizer.encode(caption)
    if len(ids) <= 2:
        # Only special tokens (e.g. an empty caption): nothing to mask but the text itself.
        return bert_tokenizer.mask_token
    # Choose a position strictly between the leading and trailing special tokens,
    # so [CLS]/[SEP] are never masked, and drop them again before decoding.
    ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
    return bert_tokenizer.decode(ids[1:-1])


def _load_example(state, row):
    """Store the example in *row* (a row of the examples TSV) on the session *state*.

    Sets ``image_file``, a randomly masked ``caption``, ``caption_lang_id`` and
    the decoded ``image`` pixels.
    """
    state.image_file = row["image_file"]
    state.caption = _mask_random_token(row["caption"].strip("- "))
    state.caption_lang_id = row["lang_id"]

    image_path = os.path.join("cc12m_data/images_vqa", state.image_file)
    state.image = plt.imread(image_path)


def app():
    """Render the masked-caption prediction page.

    The current example (image file, masked caption, language id, image
    pixels) and the loaded model are kept in the session state returned by
    ``_get_state()``. The page shows the image and an editable caption, then
    plots the five most probable tokens for the caption's [MASK] position.
    """
    state = _get_state()

    @st.cache(persist=False)
    def predict(transformed_image, caption_inputs):
        """Return the vocabulary logits the model predicts for the first [MASK].

        NOTE(review): assumes ``input_ids`` is 2-D (batch, seq) and that
        ``outputs.logits`` is (batch, seq, vocab). The original indexing
        relied on the same layout.
        """
        outputs = state.model(pixel_values=transformed_image, **caption_inputs)
        batch_idx, token_idx = np.where(
            np.asarray(caption_inputs["input_ids"]) == bert_tokenizer.mask_token_id
        )
        return np.asarray(outputs.logits[batch_idx[0], token_idx[0]])

    # allow_output_mutation=True: don't hash the (large) returned model on every rerun.
    @st.cache(persist=False, allow_output_mutation=True)
    def load_model(ckpt):
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    mlm_checkpoints = ['flax-community/clip-vision-bert-cc12m-70k']
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")

    first_index = 20
    # Initialise the session with a fixed example on the first run.
    if state.image_file is None:
        _load_example(state, dummy_data.loc[first_index])

    if state.model is None:
        # Load the MLM checkpoint once per session.
        with st.spinner("Loading model..."):
            state.model = load_model(mlm_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_example(state, dummy_data.sample(1).iloc[0])

    transformed_image = get_transformed_image(state.image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
    new_col1.image(state.image, use_column_width="always")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    caption = new_col2.text_input(
        label="Text",
        value=state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

    caption_inputs = get_text_attributes(caption)

    # Without a [MASK] there is no position to predict; indexing would raise.
    n_masks = int(
        np.sum(np.asarray(caption_inputs["input_ids"]) == bert_tokenizer.mask_token_id)
    )
    if n_masks == 0:
        st.warning("Please include a [MASK] token in the text.")
        return
    if n_masks > 1:
        st.warning("More than one [MASK] token found; only the first one is predicted.")

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        logits = predict(transformed_image, dict(caption_inputs))
    # Normalise over the whole vocabulary, then keep the five most probable tokens.
    probs = softmax(logits)
    # Ordered least-to-most likely. NOTE(review): presumably the bar-plot helper
    # draws the first entry at the bottom (plotly's default for horizontal bars),
    # which puts the best guess on top. Confirm against the helper.
    top_5_indices = np.argsort(probs)[-5:]
    labels = bert_tokenizer.convert_ids_to_tokens(top_5_indices.tolist())
    values = probs[top_5_indices]
    fig = plotly_express_horizontal_bar_plot(values, labels)
    st.plotly_chart(fig, use_container_width=True)