pr0ximaCent committed on
Commit
7839460
·
verified ·
1 Parent(s): 7bece97

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +71 -0
  2. requirements.txt +86 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
6
+ from tensorflow.keras.preprocessing.image import img_to_array
7
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
8
+ import pickle
9
+
10
# Caption model and its tokenizer are read from files next to this script.
model = tf.keras.models.load_model("model.h5")
with open("tokenizer.pkl", "rb") as tokenizer_file:
    # NOTE(review): pickle executes arbitrary code on load; only ship a
    # tokenizer.pkl that you produced yourself.
    tokenizer = pickle.load(tokenizer_file)

# Optional: precomputed features can be loaded the same way if needed.
# with open("features.pkl", "rb") as f:
#     features = pickle.load(f)

# Image feature extractor: VGG16 truncated at its second-to-last layer, so
# predict() yields that layer's activations rather than the final layer's.
_vgg16_base = VGG16()
feature_extractor = tf.keras.Model(
    inputs=_vgg16_base.input,
    outputs=_vgg16_base.layers[-2].output,
)
22
+
23
+ # Description generation function
24
def generate_caption(image, max_length=34):
    """Generate a caption for ``image`` by greedy word-by-word decoding.

    The image is resized to 224x224, run through ``preprocess_input`` and the
    module-level ``feature_extractor``; the resulting feature is fed to the
    module-level caption ``model`` together with the text generated so far.
    At every step the highest-scoring token is appended, until ``'endseq'``
    is produced, an index with no known word is predicted, or ``max_length``
    steps have run.

    Args:
        image: A PIL image (the Gradio input uses ``type="pil"``), or
            ``None`` when nothing was uploaded.
        max_length: Padded sequence length and maximum number of decoding
            steps. Must match the length the caption model was trained
            with; defaults to 34 as in the original code.

    Returns:
        The generated caption without the ``'startseq'`` marker, or a short
        prompt asking for an image when ``image`` is ``None``.
    """
    # The UI can submit with no image selected; fail softly instead of
    # raising AttributeError on ``None.resize``.
    if image is None:
        return "Please upload an image first."

    # Preprocess the image. Forcing RGB guards against RGBA / grayscale /
    # palette inputs, which would not have three channels.
    pixels = img_to_array(image.convert("RGB").resize((224, 224)))
    batch = preprocess_input(np.expand_dims(pixels, axis=0))

    # Extract features
    feature = feature_extractor.predict(batch, verbose=0)

    # Build the reverse vocabulary once rather than scanning word_index
    # linearly on every decoding step.
    index_to_word = {
        index: word for word, index in tokenizer.word_index.items()
    }

    input_text = 'startseq'
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([input_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        yhat = model.predict([feature, sequence], verbose=0)
        # int() turns the NumPy integer into a plain dict key.
        word = index_to_word.get(int(np.argmax(yhat)), '')
        # Stop on the end marker or on an index with no vocabulary entry.
        if word in ('', 'endseq'):
            break
        input_text += ' ' + word

    return input_text.replace('startseq', '').strip()
54
+
55
# Gradio Interface: static texts shown on the page.
title = "📸 Image Caption Generator"
description = "Upload an image and let the AI generate a descriptive caption for it."
theme = "soft"

# One PIL image in, one caption textbox out; flagging is turned off.
_image_input = gr.Image(type="pil")
_caption_output = gr.Textbox(label="Generated Caption")

iface = gr.Interface(
    fn=generate_caption,
    inputs=_image_input,
    outputs=_caption_output,
    title=title,
    description=description,
    theme=theme,
    allow_flagging="never",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()
requirements.txt ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gdown
2
+ absl-py==1.4.0
3
+ altair==5.1.0
4
+ astunparse==1.6.3
5
+ attrs==23.1.0
6
+ blinker==1.6.2
7
+ cachetools==5.3.1
8
+ certifi==2023.7.22
9
+ charset-normalizer==3.2.0
10
+ click==8.1.7
11
+ colorama==0.4.6
12
+ contourpy==1.1.0
13
+ cycler==0.11.0
14
+ flatbuffers==23.5.26
15
+ fonttools==4.42.1
16
+ gast==0.4.0
17
+ google-auth==2.22.0
18
+ google-auth-oauthlib==1.0.0
19
+ google-pasta==0.2.0
20
+ grpcio==1.57.0
21
+ h5py==3.9.0
22
+ idna==3.4
23
+ importlib-metadata==6.8.0
24
+ jax==0.4.14
25
+ Jinja2==3.1.2
26
+ joblib==1.3.2
27
+ jsonschema==4.19.0
28
+ jsonschema-specifications==2023.7.1
29
+ # Removed keras==2.15.0 since tensorflow includes keras now
30
+ kiwisolver==1.4.5
31
+ libclang==16.0.6
32
+ Markdown==3.4.4
33
+ markdown-it-py==3.0.0
34
+ MarkupSafe==2.1.3
35
+ matplotlib==3.7.2
36
+ mdurl==0.1.2
37
+ ml-dtypes==0.2.0
38
+ nltk==3.8.1
39
+ numpy==1.23.5
40
+ oauthlib==3.2.2
41
+ opt-einsum==3.3.0
42
+ packaging==23.1
43
+ pandas==2.0.3
44
+ Pillow==9.5.0
45
+ protobuf==4.24.2
46
+ pyarrow==13.0.0
47
+ pyasn1==0.5.0
48
+ pyasn1-modules==0.3.0
49
+ pydeck==0.8.0
50
+ Pygments==2.16.1
51
+ Pympler==1.0.1
52
+ pyparsing==3.0.9
53
+ python-dateutil==2.8.2
54
+ pytz==2023.3
55
+ pytz-deprecation-shim==0.1.0.post0
56
+ referencing==0.30.2
57
+ regex==2023.8.8
58
+ requests==2.31.0
59
+ requests-oauthlib==1.3.1
60
+ rich==13.5.2
61
+ rpds-py==0.10.0
62
+ rsa==4.9
63
+ scipy==1.11.2
64
+ six==1.16.0
65
+ smmap==5.0.0
66
+ streamlit==1.38.0
67
+ tenacity==8.2.3
68
+ tensorboard==2.15.2
69
+ tensorboard-data-server==0.7.1
70
+ tensorflow==2.15.0
71
+ tensorflow-estimator==2.15.0
72
+ tensorflow-io-gcs-filesystem==0.31.0
73
+ termcolor==2.3.0
74
+ toml==0.10.2
75
+ toolz==0.12.0
76
+ tornado==6.3.3
77
+ tqdm==4.66.1
78
+ typing_extensions==4.5.0
79
+ tzdata==2023.3
80
+ tzlocal==4.3.1
81
+ urllib3==1.26.16
82
+ validators==0.21.2
83
+ watchdog==3.0.0
84
+ Werkzeug==2.3.7
85
+ wrapt==1.14.1
86
+ zipp==3.16.2