caption / app.py
pr0ximaCent's picture
Update app.py
de58874 verified
import gradio as gr
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle
import os
# Check if required files exist
def check_required_files():
    """Report whether the model and tokenizer artifacts are present.

    Each file that exists in the current working directory is logged with
    its size; the ones that do not are collected and logged together.

    Returns:
        bool: True when both files exist, False otherwise.
    """
    absent = []
    for path in ("caption_model.h5", "tokenizer.pkl"):
        if os.path.exists(path):
            print(f"βœ“ Found {path} ({os.path.getsize(path)} bytes)")
            continue
        absent.append(path)

    if not absent:
        return True

    print(f"βœ— Missing files: {absent}")
    return False
# Run the artifact check once at import time. The boolean result is read
# further down to decide whether the tokenizer and model are loaded at all.
print("Checking required files...")
files_exist = check_required_files()
# Custom function to handle attention mechanism
def attention_function(inputs):
    """Apply attention weights to a feature tensor.

    Args:
        inputs: Two-element sequence ``(attention_weights, features)``.
            Two layouts are supported:

            * Weights of the same rank as the features, e.g.
              ``(batch, 34, 34)`` against ``(batch, 34, 512)``. The weights
              are treated as an attention matrix and combined with the
              features by a batched matrix product.
            * Weights with one axis fewer than the features, e.g.
              ``(batch, 34)``. Each weight scales the feature vector at the
              same position.

    Returns:
        The attended features, shaped like ``features``
        (``(batch, 34, 512)`` for the shapes above).
    """
    attention_weights, features = inputs

    if len(attention_weights.shape) == len(features.shape):
        # Appending an axis to same-rank weights would give (batch, 34, 34, 1),
        # which cannot broadcast against (batch, 34, 512) into the documented
        # output shape. A batched matmul is the operation that maps
        # (batch, 34, 34) x (batch, 34, 512) -> (batch, 34, 512).
        return tf.matmul(attention_weights, features)

    # Per-position scalar weights: add a trailing axis so they broadcast
    # across the feature dimension.
    attention_weights = tf.expand_dims(attention_weights, axis=-1)
    return attention_weights * features
def attention_output_shape(input_shapes):
    """Report the attention layer's output shape.

    Args:
        input_shapes: Indexable collection of input shapes; the entry at
            index 1 is the shape of the feature input.

    Returns:
        The shape of the second input, e.g. ``(None, 34, 512)``.
    """
    feature_shape = input_shapes[1]
    return feature_shape
# Alternative attention functions to try
def attention_function_v2(inputs):
    """Alternative attention mechanism: softmax-normalized weighting.

    The weights are normalized with a softmax over their last axis before
    being applied to the features.

    Args:
        inputs: Two-element sequence ``(attention_weights, features)``.

    Returns:
        * When the weights have the same rank as the features (e.g.
          ``(batch, 34, 34)`` and ``(batch, 34, 512)``): the weighted sum of
          feature vectors, computed as a batched matrix product.
        * Otherwise: the features scaled position-wise by the normalized
          weights (the weights get a trailing axis so they broadcast).
    """
    attention_weights, features = inputs

    # Normalize so the weights along the last axis sum to one.
    attention_weights = tf.nn.softmax(attention_weights, axis=-1)

    if len(attention_weights.shape) == len(features.shape):
        # Expanding same-rank weights would produce (batch, 34, 34, 1), which
        # does not broadcast against (batch, 34, 512); the weighted sum over
        # positions is a batched matmul.
        return tf.matmul(attention_weights, features)

    attention_weights = tf.expand_dims(attention_weights, axis=-1)
    return attention_weights * features
def attention_function_v3(inputs):
    """Alternative attention mechanism: scale features by summed weights.

    The attention weights are summed over their last axis and the result is
    used as a multiplicative scale on the features.

    Args:
        inputs: Two-element sequence ``(attention_weights, features)``.

    Returns:
        The scaled features, shaped like ``features``.
    """
    attention_weights, features = inputs

    # Collapse the last axis of the weights, keeping it as a size-1 axis.
    summed_weights = tf.reduce_sum(attention_weights, axis=-1, keepdims=True)

    # keepdims already leaves a broadcastable trailing axis. Add another one
    # only when the summed weights are still of lower rank than the features;
    # expanding unconditionally turned (batch, 34, 1) into (batch, 34, 1, 1),
    # which no longer lines up with (batch, 34, 512).
    if len(summed_weights.shape) < len(features.shape):
        summed_weights = tf.expand_dims(summed_weights, axis=-1)

    return summed_weights * features
# Custom Lambda layer class
class AttentionLambda(tf.keras.layers.Lambda):
    """Lambda layer with a pluggable output-shape rule.

    Args:
        function: Callable forwarded to ``tf.keras.layers.Lambda``.
        output_shape_func: Optional callable mapping the layer's input
            shape(s) to its output shape.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Lambda``.
    """

    def __init__(self, function, output_shape_func=None, **kwargs):
        super().__init__(function, **kwargs)
        self.output_shape_func = output_shape_func

    def compute_output_shape(self, input_shape):
        """Return the output shape for ``input_shape``.

        The user-supplied rule wins when one was given. Without it, a list
        of two or more shapes resolves to its second entry (the features);
        anything else is returned unchanged.
        """
        shape_rule = self.output_shape_func
        if shape_rule:
            return shape_rule(input_shape)

        has_feature_shape = isinstance(input_shape, list) and len(input_shape) >= 2
        return input_shape[1] if has_feature_shape else input_shape
# Define multiple custom objects to try different attention mechanisms
def get_custom_objects(attention_func, output_shape_func):
    """Build a ``custom_objects`` mapping that overrides the ``Lambda`` layer.

    Args:
        attention_func: Callable used as the layer function whenever the
            factory is invoked without a ``function`` (or with ``None``).
        output_shape_func: Output-shape rule handed to ``AttentionLambda``.

    Returns:
        dict: A single-entry mapping ``{'Lambda': factory}`` where the
        factory returns an ``AttentionLambda`` instance.
    """
    # NOTE(review): if the loader always supplies a serialized ``function``
    # entry, the ``attention_func`` fallback is never taken - confirm against
    # the Keras version in use.
    def build_attention_lambda(function=None, **kwargs):
        chosen = function if function is not None else attention_func
        return AttentionLambda(chosen, output_shape_func, **kwargs)

    return {'Lambda': build_attention_lambda}
# Multiple loading strategies with different attention mechanisms
def _second_input_shape(input_shape):
    """Return the second entry of a list of shapes, else the shape unchanged."""
    return input_shape[1] if isinstance(input_shape, list) else input_shape


def _load_with_attention_lambda():
    """Strategy 1: load with a Lambda factory that falls back to an attention op."""

    def custom_attention(inputs):
        """Combine ``(attention_weights, features)``; pass other inputs through."""
        if len(inputs) == 2:
            attention_weights, features = inputs
            if len(attention_weights.shape) == 3 and len(features.shape) == 3:
                # Two rank-3 tensors, e.g. (batch, 34, 34) and (batch, 34, 512):
                # appending an axis to the weights gives (batch, 34, 34, 1),
                # which cannot broadcast to (batch, 34, 512). A batched matmul
                # applies the attention matrix and keeps the feature shape.
                return tf.matmul(attention_weights, features)
            return tf.multiply(attention_weights, features)
        return inputs[0] if isinstance(inputs, list) else inputs

    # NOTE(review): if the loader always supplies a serialized ``function``
    # entry, the ``custom_attention`` fallback is never taken - confirm
    # against the Keras version in use.
    def build_lambda(function=None, output_shape=None, **kwargs):
        # The saved ``output_shape`` is deliberately ignored in favour of the
        # "second input" rule.
        return tf.keras.layers.Lambda(
            custom_attention if function is None else function,
            output_shape=_second_input_shape,
            **kwargs
        )

    return tf.keras.models.load_model(
        "caption_model.h5", custom_objects={'Lambda': build_lambda}
    )


def _load_without_compilation():
    """Strategy 2: load the model with ``compile=False``."""
    return tf.keras.models.load_model("caption_model.h5", compile=False)


def _load_with_passthrough_lambda():
    """Strategy 3: load with a Lambda factory that defaults to a pass-through."""

    def identity_function(x):
        if isinstance(x, list) and len(x) == 2:
            # For a two-input attention layer, keep the second input (features).
            return x[1]
        return x

    def build_lambda(function=identity_function, output_shape=None, **kwargs):
        return tf.keras.layers.Lambda(
            function,
            output_shape=_second_input_shape,
            **kwargs
        )

    return tf.keras.models.load_model(
        "caption_model.h5", custom_objects={'Lambda': build_lambda}
    )


def _load_with_stock_lambda():
    """Strategy 4: load with the stock Lambda class registered explicitly."""
    return tf.keras.models.load_model(
        "caption_model.h5", custom_objects={'Lambda': tf.keras.layers.Lambda}
    )


def load_model_safely():
    """Load ``caption_model.h5``, trying several strategies in order.

    Each strategy is attempted in turn. A failure is logged (message
    truncated to 200 characters) and the next strategy is tried; the first
    one that succeeds wins.

    Returns:
        The loaded Keras model.

    Raises:
        RuntimeError: When every strategy fails. The exception is chained to
            the last underlying failure. ``RuntimeError`` derives from
            ``Exception``, so callers catching ``Exception`` are unaffected.
    """
    print("Starting model loading process...")

    strategies = (
        ("Loading with custom attention Lambda", _load_with_attention_lambda),
        ("Loading without compilation", _load_without_compilation),
        ("Loading with default Lambda handling", _load_with_passthrough_lambda),
        ("Loading with minimal custom objects", _load_with_stock_lambda),
    )

    last_error = None
    for number, (label, loader) in enumerate(strategies, start=1):
        # Catching Exception broadly is intentional: any failure of one
        # strategy should fall through to the next one.
        try:
            print(f"Strategy {number}: {label}...")
            loaded = loader()
        except Exception as e:
            print(f"βœ— Strategy {number} failed: {str(e)[:200]}...")
            last_error = e
        else:
            print(f"βœ“ Strategy {number} successful!")
            return loaded

    print("All strategies failed. Model could not be loaded.")
    raise RuntimeError(
        "All model loading strategies failed. The model file may be corrupted or incompatible."
    ) from last_error
# Load your pre-trained model and tokenizer
# Both artifacts start out as None; generate_caption checks for None and
# reports which one is unavailable.
model = None
tokenizer = None

if files_exist:
    # Tokenizer first. NOTE: pickle.load executes code embedded in the file,
    # so tokenizer.pkl must come from a trusted source.
    try:
        with open("tokenizer.pkl", "rb") as handle:
            tokenizer = pickle.load(handle)
        print("βœ“ Tokenizer loaded successfully")
    except Exception as err:
        print(f"βœ— Failed to load tokenizer: {err}")
        tokenizer = None

    # Then the captioning model; a failure leaves ``model`` as None.
    try:
        model = load_model_safely()
        print("βœ“ Model loaded successfully and ready for inference!")
    except Exception as err:
        print(f"βœ— Failed to load model: {err}")
        print("The app will not work without a properly loaded model.")
        model = None
else:
    print("Cannot proceed without required files.")
# Image feature extractor model
# Instantiate VGG16 with its default arguments, then re-wrap it so the model
# outputs the activations of its second-to-last layer instead of the final
# layer's output.
# NOTE(review): layers[-2] is presumably the 4096-d "fc2" layer - confirm it
# matches the feature size the caption model was trained on.
feature_extractor = VGG16()
feature_extractor = tf.keras.Model(feature_extractor.input, feature_extractor.layers[-2].output)
# Description generation function
def generate_caption(image):
    """Generate a caption for an image using greedy decoding.

    The image is resized to 224x224, run through ``feature_extractor`` and
    the resulting features are fed, together with the caption generated so
    far, to ``model`` one word at a time. Decoding stops at ``endseq``, at an
    index missing from the tokenizer vocabulary, or after 34 steps.

    Args:
        image: PIL image to caption, or None when nothing was uploaded.

    Returns:
        str: The caption prefixed with a success marker, or a message
        prefixed with "❌" describing why no caption could be produced.
        Exceptions are reported through the return value, never raised.
    """
    try:
        if model is None:
            return "❌ Model failed to load. Please check the model file and console output for details."
        if tokenizer is None:
            return "❌ Tokenizer failed to load. Please check the tokenizer.pkl file."
        if image is None:
            # Nothing was uploaded; say so explicitly instead of surfacing an
            # AttributeError from the preprocessing below.
            return "❌ No image provided. Please upload an image first."

        # Preprocess the image. Converting to RGB guarantees three channels,
        # so grayscale or RGBA uploads do not break the VGG16 input.
        image = image.convert("RGB").resize((224, 224))
        pixels = img_to_array(image)
        pixels = np.expand_dims(pixels, axis=0)
        pixels = preprocess_input(pixels)

        # Extract features
        print("Extracting image features...")
        feature = feature_extractor.predict(pixels, verbose=0)
        print(f"Features extracted, shape: {feature.shape}")

        # Build the index -> word table once instead of scanning the whole
        # vocabulary at every decoding step. setdefault keeps the first word
        # for an index, matching a first-match linear scan.
        index_to_word = {}
        for vocab_word, vocab_index in tokenizer.word_index.items():
            index_to_word.setdefault(vocab_index, vocab_word)

        # Generate caption
        input_text = 'startseq'
        max_length = 34  # set this to your model's max_length
        print("Starting caption generation...")
        for step in range(1, max_length + 1):
            sequence = tokenizer.texts_to_sequences([input_text])[0]
            sequence = pad_sequences([sequence], maxlen=max_length)
            try:
                print(f"Prediction step {step}: input_text = '{input_text}'")
                probabilities = model.predict([feature, sequence], verbose=0)
                yhat = int(np.argmax(probabilities))
                print(f"Predicted token index: {yhat}")
            except Exception as e:
                print(f"Prediction error at step {step}: {e}")
                return f"❌ Error during prediction: {str(e)}"

            word = index_to_word.get(yhat, '')
            print(f"Predicted word: '{word}'")
            if word == 'endseq' or word == '':
                break
            input_text += ' ' + word

        caption = input_text.replace('startseq', '').strip()
        print(f"Final caption: '{caption}'")
        return f"βœ… {caption}" if caption else "❌ Unable to generate caption"
    except Exception as e:
        # Top-level boundary for the UI callback: report instead of raising.
        error_msg = f"❌ Error processing image: {str(e)}"
        print(error_msg)
        return error_msg
# Gradio Interface
# Text and theme settings handed to the Gradio interface below.
title = "πŸ“Έ Image Caption Generator"
description = "Upload an image and let the AI generate a descriptive caption for it."
theme = "soft"
# One image in, one text box out. type="pil" requests a PIL image, which is
# what generate_caption resizes.
# NOTE(review): allow_flagging="never" presumably hides the flag button -
# confirm the parameter name against the installed Gradio version.
iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Caption"),
    title=title,
    description=description,
    theme=theme,
    allow_flagging="never"
)
# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()