Upload 2 files
- app.py +71 -0
- requirements.txt +86 -0
app.py
ADDED
@@ -0,0 +1,71 @@
import gradio as gr
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle

# Load your pre-trained model and tokenizer
model = tf.keras.models.load_model("model.h5")
with open("tokenizer.pkl", "rb") as handle:
    tokenizer = pickle.load(handle)

# Load your precomputed features if required (else comment out)
# with open("features.pkl", "rb") as f:
#     features = pickle.load(f)

# Image feature extractor model
feature_extractor = VGG16()
feature_extractor = tf.keras.Model(feature_extractor.input, feature_extractor.layers[-2].output)

# Description generation function
def generate_caption(image):
    # Preprocess the image
    image = image.resize((224, 224))
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)
    image = preprocess_input(image)

    # Extract features
    feature = feature_extractor.predict(image, verbose=0)

    # Generate caption (mock example: replace with your real inference loop)
    input_text = 'startseq'
    max_length = 34  # set this to your model's max_length

    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([input_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        yhat = model.predict([feature, sequence], verbose=0)
        yhat = np.argmax(yhat)
        word = ''
        for w, i in tokenizer.word_index.items():
            if i == yhat:
                word = w
                break
        if word == 'endseq' or word == '':
            break
        input_text += ' ' + word

    caption = input_text.replace('startseq', '').strip()
    return caption

# Gradio Interface
title = "📸 Image Caption Generator"
description = "Upload an image and let the AI generate a descriptive caption for it."
theme = "soft"

iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Caption"),
    title=title,
    description=description,
    theme=theme,
    allow_flagging="never"
)

if __name__ == "__main__":
    iface.launch()
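Note: the greedy decoding loop in generate_caption scans tokenizer.word_index on every step to map the predicted index back to a word. Below is a minimal sketch of an alternative that precomputes the reverse lookup once; it assumes the same globals as app.py (model, tokenizer, feature_extractor) and that tokenizer is a standard Keras Tokenizer. The names index_to_word and generate_caption_fast are illustrative and not part of the uploaded files.

# Sketch only: assumes `model`, `tokenizer`, and `feature_extractor` are loaded
# exactly as in app.py above. `index_to_word` and `generate_caption_fast` are
# hypothetical helpers, not part of the uploaded files.
max_length = 34  # same assumption as in app.py

# Build the index -> word map once instead of scanning word_index every step
index_to_word = {index: word for word, index in tokenizer.word_index.items()}

def generate_caption_fast(image):
    image = image.resize((224, 224))
    image = preprocess_input(np.expand_dims(img_to_array(image), axis=0))
    feature = feature_extractor.predict(image, verbose=0)

    input_text = 'startseq'
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([input_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        yhat = int(np.argmax(model.predict([feature, sequence], verbose=0)))
        word = index_to_word.get(yhat, '')  # O(1) lookup instead of a linear scan
        if word in ('', 'endseq'):
            break
        input_text += ' ' + word
    return input_text.replace('startseq', '').strip()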
requirements.txt
ADDED
@@ -0,0 +1,86 @@
gdown
absl-py==1.4.0
altair==5.1.0
astunparse==1.6.3
attrs==23.1.0
blinker==1.6.2
cachetools==5.3.1
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.7
colorama==0.4.6
contourpy==1.1.0
cycler==0.11.0
flatbuffers==23.5.26
fonttools==4.42.1
gast==0.4.0
google-auth==2.22.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
grpcio==1.57.0
h5py==3.9.0
idna==3.4
importlib-metadata==6.8.0
jax==0.4.14
Jinja2==3.1.2
joblib==1.3.2
jsonschema==4.19.0
jsonschema-specifications==2023.7.1
# Removed keras==2.15.0 since tensorflow includes keras now
kiwisolver==1.4.5
libclang==16.0.6
Markdown==3.4.4
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.7.2
mdurl==0.1.2
ml-dtypes==0.2.0
nltk==3.8.1
numpy==1.23.5
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.1
pandas==2.0.3
Pillow==9.5.0
protobuf==4.24.2
pyarrow==13.0.0
pyasn1==0.5.0
pyasn1-modules==0.3.0
pydeck==0.8.0
Pygments==2.16.1
Pympler==1.0.1
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2023.3
pytz-deprecation-shim==0.1.0.post0
referencing==0.30.2
regex==2023.8.8
requests==2.31.0
requests-oauthlib==1.3.1
rich==13.5.2
rpds-py==0.10.0
rsa==4.9
scipy==1.11.2
six==1.16.0
smmap==5.0.0
streamlit==1.38.0
tenacity==8.2.3
tensorboard==2.15.2
tensorboard-data-server==0.7.1
tensorflow==2.15.0
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.31.0
termcolor==2.3.0
toml==0.10.2
toolz==0.12.0
tornado==6.3.3
tqdm==4.66.1
typing_extensions==4.5.0
tzdata==2023.3
tzlocal==4.3.1
urllib3==1.26.16
validators==0.21.2
watchdog==3.0.0
Werkzeug==2.3.7
wrapt==1.14.1
zipp==3.16.2