pr0ximaCent commited on
Commit
de58874
Β·
verified Β·
1 Parent(s): 7034074

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -41
app.py CHANGED
@@ -6,6 +6,27 @@ from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
6
  from tensorflow.keras.preprocessing.image import img_to_array
7
  from tensorflow.keras.preprocessing.sequence import pad_sequences
8
  import pickle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Custom function to handle attention mechanism
11
  def attention_function(inputs):
@@ -70,53 +91,111 @@ def get_custom_objects(attention_func, output_shape_func):
70
 
71
  # Multiple loading strategies with different attention mechanisms
72
  def load_model_safely():
73
- attention_strategies = [
74
- (attention_function, attention_output_shape),
75
- (attention_function_v2, attention_output_shape),
76
- (attention_function_v3, attention_output_shape),
77
- ]
78
 
79
- for i, (att_func, shape_func) in enumerate(attention_strategies, 1):
80
- try:
81
- print(f"Trying attention strategy {i}...")
82
- custom_objects = get_custom_objects(att_func, shape_func)
83
- model = tf.keras.models.load_model("caption_model.h5", custom_objects=custom_objects)
84
- print(f"Model loaded successfully using attention strategy {i}!")
85
- return model
86
- except Exception as e:
87
- print(f"Attention strategy {i} failed: {e}")
88
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- # If all attention strategies fail, try loading without compilation
91
  try:
92
- print("Trying to load without compilation...")
93
  model = tf.keras.models.load_model("caption_model.h5", compile=False)
94
- print("Model loaded without compilation!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  return model
 
96
  except Exception as e:
97
- print(f"Loading without compilation failed: {e}")
98
 
99
- # Last resort: try to load and rebuild the model
100
  try:
101
- print("Attempting to load model weights only...")
102
- # This is a more complex approach that would require knowing the model architecture
103
- raise Exception("Model architecture reconstruction needed")
104
- except:
105
- pass
 
 
106
 
107
- raise Exception("All loading strategies failed. The model may need to be retrained or converted.")
 
108
 
109
  # Load your pre-trained model and tokenizer
110
- try:
111
- model = load_model_safely()
112
- except Exception as e:
113
- print(f"Failed to load model: {e}")
114
- print("Creating a dummy model for testing...")
115
- # Create a simple dummy model for testing the interface
116
  model = None
117
-
118
- with open("tokenizer.pkl", "rb") as handle:
119
- tokenizer = pickle.load(handle)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # Image feature extractor model
122
  feature_extractor = VGG16()
@@ -126,7 +205,10 @@ feature_extractor = tf.keras.Model(feature_extractor.input, feature_extractor.la
126
  def generate_caption(image):
127
  try:
128
  if model is None:
129
- return "Model failed to load. Please check the model file."
 
 
 
130
 
131
  # Preprocess the image
132
  image = image.resize((224, 224))
@@ -135,22 +217,27 @@ def generate_caption(image):
135
  image = preprocess_input(image)
136
 
137
  # Extract features
 
138
  feature = feature_extractor.predict(image, verbose=0)
 
139
 
140
  # Generate caption
141
  input_text = 'startseq'
142
  max_length = 34 # set this to your model's max_length
143
 
144
- for _ in range(max_length):
 
145
  sequence = tokenizer.texts_to_sequences([input_text])[0]
146
  sequence = pad_sequences([sequence], maxlen=max_length)
147
 
148
  try:
 
149
  yhat = model.predict([feature, sequence], verbose=0)
150
  yhat = np.argmax(yhat)
 
151
  except Exception as e:
152
- print(f"Prediction error: {e}")
153
- return f"Error during prediction: {str(e)}"
154
 
155
  word = ''
156
  for w, i in tokenizer.word_index.items():
@@ -158,15 +245,19 @@ def generate_caption(image):
158
  word = w
159
  break
160
 
 
161
  if word == 'endseq' or word == '':
162
  break
163
  input_text += ' ' + word
164
 
165
  caption = input_text.replace('startseq', '').strip()
166
- return caption if caption else "Unable to generate caption"
 
167
 
168
  except Exception as e:
169
- return f"Error processing image: {str(e)}"
 
 
170
 
171
  # Gradio Interface
172
  title = "πŸ“Έ Image Caption Generator"
 
6
  from tensorflow.keras.preprocessing.image import img_to_array
7
  from tensorflow.keras.preprocessing.sequence import pad_sequences
8
  import pickle
9
+ import os
10
+
11
+ # Check if required files exist
12
+ def check_required_files():
13
+ required_files = ["caption_model.h5", "tokenizer.pkl"]
14
+ missing_files = []
15
+
16
+ for file in required_files:
17
+ if not os.path.exists(file):
18
+ missing_files.append(file)
19
+ else:
20
+ size = os.path.getsize(file)
21
+ print(f"βœ“ Found {file} ({size} bytes)")
22
+
23
+ if missing_files:
24
+ print(f"βœ— Missing files: {missing_files}")
25
+ return False
26
+ return True
27
+
28
+ print("Checking required files...")
29
+ files_exist = check_required_files()
30
 
31
  # Custom function to handle attention mechanism
32
  def attention_function(inputs):
 
91
 
92
  # Multiple loading strategies with different attention mechanisms
93
  def load_model_safely():
94
+ print("Starting model loading process...")
 
 
 
 
95
 
96
+ # Strategy 1: Try with custom Lambda that handles the attention operation
97
+ try:
98
+ print("Strategy 1: Loading with custom attention Lambda...")
99
+
100
+ def custom_attention(inputs):
101
+ """Handle attention mechanism between two inputs"""
102
+ if len(inputs) == 2:
103
+ attention_weights, features = inputs
104
+ # Simple attention: multiply attention weights with features
105
+ # Expand attention weights to match feature dimensions
106
+ if len(attention_weights.shape) == 3 and len(features.shape) == 3:
107
+ attention_weights = tf.expand_dims(attention_weights, axis=-1)
108
+ return tf.multiply(attention_weights, features)
109
+ return inputs[0] if isinstance(inputs, list) else inputs
110
+
111
+ custom_objects = {
112
+ 'Lambda': lambda function=None, output_shape=None, **kwargs:
113
+ tf.keras.layers.Lambda(
114
+ custom_attention if function is None else function,
115
+ output_shape=lambda input_shape: input_shape[1] if isinstance(input_shape, list) else input_shape,
116
+ **kwargs
117
+ )
118
+ }
119
+
120
+ model = tf.keras.models.load_model("caption_model.h5", custom_objects=custom_objects)
121
+ print("βœ“ Strategy 1 successful!")
122
+ return model
123
+
124
+ except Exception as e:
125
+ print(f"βœ— Strategy 1 failed: {str(e)[:200]}...")
126
 
127
+ # Strategy 2: Load with compile=False and try to fix compilation later
128
  try:
129
+ print("Strategy 2: Loading without compilation...")
130
  model = tf.keras.models.load_model("caption_model.h5", compile=False)
131
+ print("βœ“ Strategy 2 successful!")
132
+ return model
133
+
134
+ except Exception as e:
135
+ print(f"βœ— Strategy 2 failed: {str(e)[:200]}...")
136
+
137
+ # Strategy 3: Try loading with TensorFlow's built-in Lambda handling
138
+ try:
139
+ print("Strategy 3: Loading with default Lambda handling...")
140
+
141
+ def identity_function(x):
142
+ if isinstance(x, list) and len(x) == 2:
143
+ # For attention mechanism, return the second input (features)
144
+ return x[1]
145
+ return x
146
+
147
+ custom_objects = {
148
+ 'Lambda': lambda function=identity_function, output_shape=None, **kwargs:
149
+ tf.keras.layers.Lambda(
150
+ function,
151
+ output_shape=lambda input_shape: input_shape[1] if isinstance(input_shape, list) else input_shape,
152
+ **kwargs
153
+ )
154
+ }
155
+
156
+ model = tf.keras.models.load_model("caption_model.h5", custom_objects=custom_objects)
157
+ print("βœ“ Strategy 3 successful!")
158
  return model
159
+
160
  except Exception as e:
161
+ print(f"βœ— Strategy 3 failed: {str(e)[:200]}...")
162
 
163
+ # Strategy 4: Try with minimal custom objects
164
  try:
165
+ print("Strategy 4: Loading with minimal custom objects...")
166
+ model = tf.keras.models.load_model("caption_model.h5", custom_objects={'Lambda': tf.keras.layers.Lambda})
167
+ print("βœ“ Strategy 4 successful!")
168
+ return model
169
+
170
+ except Exception as e:
171
+ print(f"βœ— Strategy 4 failed: {str(e)[:200]}...")
172
 
173
+ print("All strategies failed. Model could not be loaded.")
174
+ raise Exception("All model loading strategies failed. The model file may be corrupted or incompatible.")
175
 
176
  # Load your pre-trained model and tokenizer
177
+ if not files_exist:
178
+ print("Cannot proceed without required files.")
 
 
 
 
179
  model = None
180
+ tokenizer = None
181
+ else:
182
+ # Load tokenizer first
183
+ try:
184
+ with open("tokenizer.pkl", "rb") as handle:
185
+ tokenizer = pickle.load(handle)
186
+ print("βœ“ Tokenizer loaded successfully")
187
+ except Exception as e:
188
+ print(f"βœ— Failed to load tokenizer: {e}")
189
+ tokenizer = None
190
+
191
+ # Load model
192
+ try:
193
+ model = load_model_safely()
194
+ print("βœ“ Model loaded successfully and ready for inference!")
195
+ except Exception as e:
196
+ print(f"βœ— Failed to load model: {e}")
197
+ print("The app will not work without a properly loaded model.")
198
+ model = None
199
 
200
  # Image feature extractor model
201
  feature_extractor = VGG16()
 
205
  def generate_caption(image):
206
  try:
207
  if model is None:
208
+ return "❌ Model failed to load. Please check the model file and console output for details."
209
+
210
+ if tokenizer is None:
211
+ return "❌ Tokenizer failed to load. Please check the tokenizer.pkl file."
212
 
213
  # Preprocess the image
214
  image = image.resize((224, 224))
 
217
  image = preprocess_input(image)
218
 
219
  # Extract features
220
+ print("Extracting image features...")
221
  feature = feature_extractor.predict(image, verbose=0)
222
+ print(f"Features extracted, shape: {feature.shape}")
223
 
224
  # Generate caption
225
  input_text = 'startseq'
226
  max_length = 34 # set this to your model's max_length
227
 
228
+ print("Starting caption generation...")
229
+ for i in range(max_length):
230
  sequence = tokenizer.texts_to_sequences([input_text])[0]
231
  sequence = pad_sequences([sequence], maxlen=max_length)
232
 
233
  try:
234
+ print(f"Prediction step {i+1}: input_text = '{input_text}'")
235
  yhat = model.predict([feature, sequence], verbose=0)
236
  yhat = np.argmax(yhat)
237
+ print(f"Predicted token index: {yhat}")
238
  except Exception as e:
239
+ print(f"Prediction error at step {i+1}: {e}")
240
+ return f"❌ Error during prediction: {str(e)}"
241
 
242
  word = ''
243
  for w, i in tokenizer.word_index.items():
 
245
  word = w
246
  break
247
 
248
+ print(f"Predicted word: '{word}'")
249
  if word == 'endseq' or word == '':
250
  break
251
  input_text += ' ' + word
252
 
253
  caption = input_text.replace('startseq', '').strip()
254
+ print(f"Final caption: '{caption}'")
255
+ return f"βœ… {caption}" if caption else "❌ Unable to generate caption"
256
 
257
  except Exception as e:
258
+ error_msg = f"❌ Error processing image: {str(e)}"
259
+ print(error_msg)
260
+ return error_msg
261
 
262
  # Gradio Interface
263
  title = "πŸ“Έ Image Caption Generator"