import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
import requests
from predict import generate_text
from model import load_model
from streamlit_image_select import image_select
# Configure Streamlit page
st.set_page_config(page_title="Caption Machine", page_icon="📸")
# Load the model, tokenizer, and image transform once and cache them in session state
# so Streamlit reruns reuse them instead of reloading the model on every interaction
if 'model' not in st.session_state:
    model, image_transform, tokenizer = load_model()
    st.session_state['model'] = model
    st.session_state['image_transform'] = image_transform
    st.session_state['tokenizer'] = tokenizer
# Force a responsive column layout (also on mobile) and style the "Or" separator
st.write(
    """<style>
    [data-testid="column"] {
        width: calc(50% - 1rem);
        flex: 1 1 calc(50% - 1rem);
        min-width: calc(50% - 1rem);
    }
    .separator {
        display: flex;
        align-items: center;
        text-align: center;
    }
    .separator::before,
    .separator::after {
        content: '';
        flex: 1;
        border-bottom: 1px solid #000;
    }
    .separator:not(:empty)::before {
        margin-right: .25em;
    }
    .separator:not(:empty)::after {
        margin-left: .25em;
    }
    </style>""",
    unsafe_allow_html=True,
)
# Render Streamlit page
st.title("Image Captioner")
st.markdown(
    "This app uses OpenAI's [GPT-2](https://openai.com/research/better-language-models) and [CLIP](https://openai.com/research/clip) models to generate image captions. The model architecture was inspired by [ClipCap: CLIP Prefix for Image Captioning](https://arxiv.org/abs/2111.09734), which uses the CLIP encoding as a prefix and fine-tunes GPT-2 to generate the caption."
)
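
# A minimal sketch of the prefix construction described above, assuming a torch-based
# ClipCap-style model. It is not called by the app, and the names clip_model,
# mapping_network, and gpt2 are illustrative assumptions; the real logic lives in model.py.
def _build_prefix_sketch(clip_model, mapping_network, gpt2, pixel_values, token_ids):
    import torch
    with torch.no_grad():
        image_embed = clip_model.encode_image(pixel_values)  # CLIP image embedding, (1, clip_dim)
    prefix_embeds = mapping_network(image_embed)              # projected to (1, prefix_len, gpt2_dim)
    token_embeds = gpt2.transformer.wte(token_ids)            # GPT-2 token embeddings, (1, seq_len, gpt2_dim)
    # GPT-2 attends over the prefix and token embeddings jointly to produce the caption
    return torch.cat([prefix_embeds, token_embeds], dim=1)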
# Select image or upload image
select_file = image_select(
    label="Select a photo:",
    images=[
        "https://farm5.staticflickr.com/4084/5093294428_2f50d54acb_z.jpg",
        "https://farm8.staticflickr.com/7044/6855243647_cd204d079c_z.jpg",
        "http://farm4.staticflickr.com/3016/2650267987_f478c8d682_z.jpg",
        "https://farm8.staticflickr.com/7249/6913786280_c145ecc433_z.jpg",
    ],
    # captions=["A cat", "Another cat", "Oh look, a cat!", "Guess what, a cat..."],
)
st.markdown("<div class='separator'>Or</div>", unsafe_allow_html=True)
upload_file = st.file_uploader("Upload an image:", type=['png','jpg','jpeg'])
# If an image was uploaded or selected, display it and generate a caption
if upload_file or select_file:
    img = None
    if upload_file:
        img = Image.open(upload_file)
    elif select_file:
        img = Image.open(requests.get(select_file, stream=True).raw)
    st.image(img)
    with st.spinner('Generating caption...'):
        caption = generate_text(
            st.session_state['model'],
            img,
            st.session_state['tokenizer'],
            st.session_state['image_transform'],
        )
    st.success(f"Result: {caption}")
# Model architecture information
with st.expander("See model architecture"):
    st.markdown(
        """
Steps:
1. Feed the image into the CLIP image encoder to get an image embedding
2. Map the image embedding into the text embedding shape (the caption prefix)
3. Feed the caption text into the GPT-2 text embedder to get a text embedding
4. Concatenate the two embeddings and feed them into the GPT-2 attention layers
"""
    )
    st.write("\nModel Architecture:")
    model_img = Image.open('./model.png')
    st.image(model_img, width=450)
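
# A minimal sketch of the greedy decoding loop that predict.generate_text likely performs,
# assuming the cached model exposes a Hugging Face GPT-2 as model.gpt2. The attribute
# names here are illustrative assumptions, not the real API of model.py / predict.py.
def _greedy_caption_sketch(model, tokenizer, prefix_embeds, max_tokens=30):
    import torch
    generated = prefix_embeds                                    # (1, prefix_len, gpt2_dim)
    token_ids = []
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model.gpt2(inputs_embeds=generated).logits  # (1, seq_len, vocab_size)
            next_id = logits[:, -1, :].argmax(dim=-1)            # greedy choice for the next token
            if next_id.item() == tokenizer.eos_token_id:
                break
            token_ids.append(next_id.item())
            next_embed = model.gpt2.transformer.wte(next_id).unsqueeze(1)
            generated = torch.cat([generated, next_embed], dim=1)
    return tokenizer.decode(token_ids)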