File size: 3,525 Bytes
20dcad7
 
 
b0e280f
20dcad7
4cea813
b0e280f
 
 
4cea813
20dcad7
 
b0e280f
20dcad7
d14c041
 
 
 
 
 
 
 
 
 
 
 
 
 
20dcad7
 
 
 
 
 
 
 
 
9c68392
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20dcad7
 
 
 
 
 
 
9c68392
20dcad7
 
 
 
b0e280f
 
 
 
 
 
 
 
 
 
 
 
9c68392
b0e280f
20dcad7
 
 
 
 
b0e280f
 
 
 
 
 
 
 
9c68392
b0e280f
 
 
20dcad7
b0e280f
20dcad7
4cea813
b0e280f
 
 
 
 
 
 
 
9c68392
 
 
 
 
 
 
 
b0e280f
9c68392
b0e280f
9c68392
4cea813
20dcad7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import io
from pathlib import Path

import requests
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
from streamlit_image_select import image_select

from predict import generate_text
from model import load_model


# Configure Streamlit page (kept as the first Streamlit call in the script).
st.set_page_config(page_title="Caption Machine", page_icon="📸")

# Load the captioning model once per browser session.
# Streamlit re-executes this whole script on every widget interaction, so the
# expensive load_model() call is guarded by the session-state check instead of
# running unconditionally on every rerun.
# NOTE(review): st.cache_resource would share one model across all sessions;
# consider it if the deployed Streamlit version supports it.
_MODEL_STATE_KEYS = ("model", "image_transform", "tokenizer")

if any(key not in st.session_state for key in _MODEL_STATE_KEYS):
    # load_model() returns (model, image_transform, tokenizer), in that order.
    loaded = dict(zip(_MODEL_STATE_KEYS, load_model()))
    for key, value in loaded.items():
        # Only fill keys that are missing, matching the original per-key guards.
        if key not in st.session_state:
            st.session_state[key] = value



# Inject page-wide CSS:
#  * keep st.columns side by side (two per row) on narrow/mobile screens too;
#  * style the ".separator" element used for the "Or" divider between the
#    sample-image picker and the uploader.
# The divider lines use currentColor so they follow the theme's text colour and
# stay visible on a dark theme, where a hard-coded black line disappears.
# NOTE(review): the [data-testid="column"] selector depends on Streamlit's
# internal DOM and may need updating for newer Streamlit versions.
st.write(
    """<style>
    [data-testid="column"] {
        width: calc(50% - 1rem);
        flex: 1 1 calc(50% - 1rem);
        min-width: calc(50% - 1rem);
    }

    .separator {
        display: flex;
        align-items: center;
        text-align: center;
    }

    .separator::before,
    .separator::after {
        content: '';
        flex: 1;
        border-bottom: 1px solid currentColor;
    }

    .separator:not(:empty)::before {
        margin-right: .25em;
    }

    .separator:not(:empty)::after {
        margin-left: .25em;
    }

    </style>""",
    unsafe_allow_html=True,
)

# Render the page header and a short description of the captioning approach.
st.title("Image Captioner")
st.markdown(
    "This app utilizes OpenAI's [GPT-2](https://openai.com/research/better-language-models) "
    "and [CLIP](https://openai.com/research/clip) models to generate image captions. "
    "The model architecture was inspired by "
    "[ClipCap: CLIP Prefix for Image Captioning](https://arxiv.org/abs/2111.09734), "
    "which uses the CLIP encoding as a prefix and fine-tunes a GPT-2 model "
    "to generate the caption."
)



# Sample photos offered in the picker. All URLs use HTTPS so the browser does
# not block them as mixed content, and so the app's own download of the
# selected URL (via requests) is encrypted.
SAMPLE_IMAGE_URLS = [
    "https://farm5.staticflickr.com/4084/5093294428_2f50d54acb_z.jpg",
    "https://farm8.staticflickr.com/7044/6855243647_cd204d079c_z.jpg",
    "https://farm4.staticflickr.com/3016/2650267987_f478c8d682_z.jpg",
    "https://farm8.staticflickr.com/7249/6913786280_c145ecc433_z.jpg",
]

# Let the user either pick one of the sample photos or upload their own.
# select_file is the URL of the chosen sample; upload_file is the uploaded file
# (falsy when nothing has been uploaded).
select_file = image_select(
    label="Select a photo:",
    images=SAMPLE_IMAGE_URLS,
)

st.markdown("<div class='separator'>Or</div>", unsafe_allow_html=True)

upload_file = st.file_uploader("Upload an image:", type=["png", "jpg", "jpeg"])


# Seconds to wait for a sample image download before giving up.
_DOWNLOAD_TIMEOUT_S = 10


def _open_image(uploaded, url):
    """Open the active image, or show an error and return None.

    An uploaded file takes priority over the selected sample URL. The image is
    decoded eagerly so that corrupt or truncated data is reported here instead
    of failing later inside st.image or the model.

    Args:
        uploaded: The st.file_uploader result (falsy when nothing was uploaded).
        url: The sample-image URL chosen in the picker.

    Returns:
        A loaded PIL image, or None if it could not be downloaded or decoded.
    """
    try:
        if uploaded:
            image = Image.open(uploaded)
        else:
            # Download the whole image into memory:
            # - the timeout stops a slow host from hanging the app;
            # - raise_for_status() stops an HTTP error page being fed to PIL;
            # - .content (unlike .raw) undoes any Content-Encoding and releases
            #   the connection.
            response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_S)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; force decoding now so bad data raises here.
        image.load()
    # RequestException subclasses OSError, so it must be caught first.
    except requests.RequestException as err:
        st.error(f"Could not download the selected image: {err}")
        return None
    except OSError as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot decode, and OSError for truncated files.
        st.error(f"Could not read the image: {err}")
        return None
    return image


# Caption whichever image is active: an uploaded file wins over the sample pick.
if upload_file or select_file:
    img = _open_image(upload_file, select_file)

    if img is not None:
        st.image(img)

        with st.spinner("Generating caption..."):
            # NOTE(review): assumes image_transform expects 3-channel RGB
            # input; converting here keeps RGBA/palette/greyscale uploads from
            # reaching the model as-is. TODO confirm against model.py/predict.py.
            caption = generate_text(
                st.session_state["model"],
                img.convert("RGB"),
                st.session_state["tokenizer"],
                st.session_state["image_transform"],
            )

        st.success(f"Result: {caption}")
    

# Model information: a collapsible section describing the captioning pipeline
# and showing the architecture diagram.
with st.expander("See model architecture"):
    st.markdown(
        """
        Steps:
        1.  Feed the image into the CLIP image encoder to get an image embedding
        2.  Map the image embedding into the text-embedding shape
        3.  Feed the text into the GPT-2 text embedder to get a text embedding
        4.  Concatenate the two embeddings and feed them into the GPT-2 attention layers
        """)

    st.write(" \nModel Architecture:  ")
    # Resolve the diagram next to this script rather than the current working
    # directory, so the app still finds it when launched from another folder.
    model_img_path = Path(__file__).resolve().parent / "model.png"
    # Close the file handle once Streamlit has rendered the image.
    with Image.open(model_img_path) as model_img:
        st.image(model_img, width=450)