File size: 10,491 Bytes
46e1197
 
 
 
 
 
 
 
de58874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46e1197
7034074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4145d27
 
7034074
 
 
 
 
 
4145d27
7034074
 
 
 
 
 
 
 
 
4145d27
7034074
4145d27
de58874
4145d27
de58874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4145d27
de58874
7034074
de58874
7034074
de58874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7034074
de58874
7034074
de58874
7034074
de58874
7034074
de58874
 
 
 
 
 
 
7034074
de58874
 
4145d27
46e1197
de58874
 
7034074
de58874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46e1197
 
 
 
 
 
 
4145d27
7034074
de58874
 
 
 
7034074
4145d27
 
 
 
 
 
 
de58874
4145d27
de58874
4145d27
 
 
 
 
de58874
 
4145d27
 
 
 
de58874
4145d27
 
de58874
4145d27
de58874
 
4145d27
 
 
 
 
 
 
de58874
4145d27
46e1197
4145d27
 
 
de58874
 
4145d27
 
de58874
 
 
46e1197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4145d27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import gradio as gr
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle
import os

# Check if required files exist
def check_required_files():
    """Report whether the model artifacts are present in the working directory.

    Checks ``caption_model.h5`` and ``tokenizer.pkl`` (relative paths),
    printing the size of each file found and, if any are absent, one line
    listing the missing ones.

    Returns:
        True when every required file exists, otherwise False.
    """
    missing_files = []
    for path in ("caption_model.h5", "tokenizer.pkl"):
        if os.path.exists(path):
            print(f"βœ“ Found {path} ({os.path.getsize(path)} bytes)")
        else:
            missing_files.append(path)

    if not missing_files:
        return True
    print(f"βœ— Missing files: {missing_files}")
    return False

# Runs at import time: report which artifacts are present and keep the result
# so the loading code below can fall back to model = tokenizer = None.
print("Checking required files...")
files_exist = check_required_files()

# Custom function to handle attention mechanism
def attention_function(inputs):
    """Apply a square attention matrix to a sequence of feature vectors.

    The shapes below are assumptions about the saved model's Lambda layer
    (TODO confirm against the architecture):

    * ``attention_weights``: ``(batch, 34, 34)``, one weight per
      (output step, input step) pair.
    * ``features``: ``(batch, 34, 512)``, one feature vector per step.

    Args:
        inputs: Pair ``(attention_weights, features)``.

    Returns:
        ``(batch, 34, 512)`` attended features, i.e. ``weights @ features``.
    """
    attention_weights, features = inputs
    # Batched matrix product: each output step is a weighted sum over all
    # input steps, collapsing the key axis to give (batch, steps, dim).
    # expand_dims + element-wise multiply would instead broadcast to a 4-D
    # (batch, steps, steps, dim) tensor, or fail for most batch sizes.
    # `@` dispatches to tf.linalg.matmul for TensorFlow tensors.
    return attention_weights @ features

def attention_output_shape(input_shapes):
    """Return the output shape of the attention Lambda.

    Args:
        input_shapes: Indexable collection of input shapes; index 0 is the
            attention-weight shape and index 1 the feature shape.

    Returns:
        The feature shape, unchanged (e.g. ``(None, 34, 512)``).
    """
    # The attended output keeps the layout of the feature input.
    features_shape = input_shapes[1]
    return features_shape

# Alternative attention functions to try
def attention_function_v2(inputs):
    """Softmax attention: a weighted sum of the feature vectors.

    Each row of ``attention_weights`` is normalised with a softmax over the
    last axis. The rows are then used to average ``features`` across steps.

    Args:
        inputs: Pair ``(attention_weights, features)``, assumed to be
            ``(batch, steps, steps)`` and ``(batch, steps, dim)``. TODO:
            confirm against the saved model.

    Returns:
        Tensor of shape ``(batch, steps, dim)``.
    """
    attention_weights, features = inputs
    # Normalise so each output step's weights sum to 1.
    attention_weights = tf.nn.softmax(attention_weights, axis=-1)
    # The batched matrix product performs the weighted sum over steps.
    # expand_dims followed by an element-wise multiply would instead
    # broadcast to 4-D (or fail for batch sizes other than 1 or `steps`).
    return attention_weights @ features

def attention_function_v3(inputs):
    """Scale each step's feature vector by the sum of its attention row.

    This is not dot-product attention. Steps are not mixed: every feature
    vector is multiplied by a single per-step scalar.

    Args:
        inputs: Pair ``(attention_weights, features)``, assumed to be
            ``(batch, steps, steps)`` and ``(batch, steps, dim)``. TODO:
            confirm against the saved model.

    Returns:
        Tensor with the same shape as ``features``.
    """
    attention_weights, features = inputs
    # keepdims=True already leaves a trailing size-1 axis, giving
    # (batch, steps, 1), which broadcasts cleanly against (batch, steps, dim).
    # An extra expand_dims would make it (batch, steps, 1, 1) and blow the
    # product up to 4-D.
    step_scale = tf.reduce_sum(attention_weights, axis=-1, keepdims=True)
    return step_scale * features

# Custom Lambda layer class
class AttentionLambda(tf.keras.layers.Lambda):
    """``Lambda`` layer with a pluggable output-shape callback.

    Args:
        function: Callable applied to the layer inputs; forwarded to
            ``tf.keras.layers.Lambda``.
        output_shape_func: Optional callable mapping input shape(s) to the
            output shape. If it is falsy, a list of at least two input
            shapes maps to its second entry; any other value is returned
            unchanged.
        **kwargs: Extra keyword arguments for ``tf.keras.layers.Lambda``.
    """

    def __init__(self, function, output_shape_func=None, **kwargs):
        super().__init__(function, **kwargs)
        self.output_shape_func = output_shape_func

    def compute_output_shape(self, input_shape):
        """Return the layer's output shape for ``input_shape``."""
        shape_fn = self.output_shape_func
        if shape_fn:
            return shape_fn(input_shape)
        # Fallback: with several inputs, the output mirrors the second
        # (feature) input.
        has_feature_input = isinstance(input_shape, list) and len(input_shape) >= 2
        return input_shape[1] if has_feature_input else input_shape

# Define multiple custom objects to try different attention mechanisms
def get_custom_objects(attention_func, output_shape_func):
    """Build a ``custom_objects`` mapping that rebuilds saved Lambda layers.

    When deserialising a ``Lambda`` entry, Keras calls the mapped factory with
    the layer's stored config as keyword arguments. For a layer that wrapped
    a Python lambda, ``function`` in that config is the *serialised* code,
    not a callable. The config also carries serialisation-only keys
    (``function_type``, ``module``, ``output_shape_type``, ...) that the
    Lambda constructor rejects as unknown keyword arguments. This is tf.keras
    2.x behaviour; verify on other versions.

    The factory therefore forwards only generic layer options. It substitutes
    ``attention_func`` whenever ``function`` is not callable.

    Args:
        attention_func: Callable used as the layer body when the stored
            function cannot be used directly.
        output_shape_func: Optional callable handed to ``AttentionLambda``
            to compute the layer's output shape.

    Returns:
        ``{'Lambda': factory}``, suitable for
        ``tf.keras.models.load_model(..., custom_objects=...)``.
    """
    # Options every Keras layer accepts. The rest of a Lambda config is
    # tied to the original function and is dropped: the serialised
    # function/output_shape/mask payloads, their type/module tags, and
    # ``arguments``.
    generic_layer_keys = ('name', 'trainable', 'dtype')

    def build_lambda(function=None, **config):
        body = function if callable(function) else attention_func
        layer_kwargs = {key: config[key] for key in generic_layer_keys if key in config}
        return AttentionLambda(body, output_shape_func, **layer_kwargs)

    return {'Lambda': build_lambda}

# Multiple loading strategies with different attention mechanisms
def _lambda_output_shape(input_shape):
    """Output-shape hook for rebuilt Lambdas: the second shape of a list, else unchanged."""
    return input_shape[1] if isinstance(input_shape, list) else input_shape


def _attention_lambda_body(inputs):
    """Stand-in Lambda body: batched ``weights @ features`` for two rank-3 tensors.

    Assumes the saved layer combined ``(batch, 34, 34)`` weights with
    ``(batch, 34, 512)`` features (TODO confirm against the training code).
    Any other input is passed through: the first element of a list/tuple,
    or the input itself.
    """
    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
        attention_weights, features = inputs
        if len(attention_weights.shape) == 3 and len(features.shape) == 3:
            return tf.matmul(attention_weights, features)
    return inputs[0] if isinstance(inputs, (list, tuple)) else inputs


def _features_lambda_body(inputs):
    """Stand-in Lambda body that ignores the attention weights and returns the features."""
    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
        return inputs[1]
    return inputs


def _lambda_rebuilder(replacement):
    """Return a ``custom_objects['Lambda']`` factory that rebuilds Lambdas around *replacement*."""
    def build(function=None, **config):
        # Keras calls this with the raw layer config. For lambda-based layers,
        # `function` is the serialised payload (not callable). The config also
        # holds serialisation-only keys (function_type, module,
        # output_shape_type, ...) that the Lambda constructor rejects, so only
        # generic layer options are forwarded.
        layer_kwargs = {key: config[key] for key in ("name", "trainable", "dtype") if key in config}
        return tf.keras.layers.Lambda(
            function if callable(function) else replacement,
            output_shape=_lambda_output_shape,
            **layer_kwargs,
        )
    return build


def load_model_safely(model_path="caption_model.h5"):
    """Load the caption model, falling back to stand-in Lambda bodies if needed.

    Strategies are tried in order, each with ``compile=False``, since the
    app only calls ``predict``:

    1. Plain deserialisation. This runs the Lambda code stored in the file,
       so it is the only strategy that reproduces the saved model exactly.
    2. Every Lambda rebuilt around ``_attention_lambda_body``.
    3. Every Lambda rebuilt around ``_features_lambda_body``.

    Strategies 2 and 3 are for files whose serialised Lambda code cannot be
    loaded in the current environment. A model loaded that way may compute
    something other than what was trained, so a warning is printed.

    Args:
        model_path: Path of the saved Keras model. Defaults to the file the
            app ships with.

    Returns:
        The loaded ``tf.keras`` model.

    Raises:
        RuntimeError: If every strategy fails. The last failure is chained
            as ``__cause__``.
    """
    print("Starting model loading process...")

    # The faithful load goes first: the substitution factories work, so
    # trying them earlier would silently swap the trained Lambda code for a
    # guess. A compile=True retry with the stock Lambda class is omitted:
    # deserialisation is identical, so it cannot succeed where strategy 1
    # failed.
    strategies = (
        ("Loading without compilation", None),
        ("Loading with custom attention Lambda", _attention_lambda_body),
        ("Loading with feature pass-through Lambda", _features_lambda_body),
    )

    last_error = None
    for number, (label, replacement) in enumerate(strategies, start=1):
        print(f"Strategy {number}: {label}...")
        load_kwargs = {"compile": False}
        if replacement is not None:
            load_kwargs["custom_objects"] = {"Lambda": _lambda_rebuilder(replacement)}
        try:
            model = tf.keras.models.load_model(model_path, **load_kwargs)
        except Exception as exc:  # any load failure -> try the next strategy
            last_error = exc
            print(f"βœ— Strategy {number} failed: {str(exc)[:200]}...")
            continue
        print(f"βœ“ Strategy {number} successful!")
        if replacement is not None:
            print("Warning: Lambda layers were replaced with a stand-in function; "
                  "captions may differ from the trained model.")
        return model

    print("All strategies failed. Model could not be loaded.")
    raise RuntimeError(
        "All model loading strategies failed. The model file may be corrupted or incompatible."
    ) from last_error

# Load your pre-trained model and tokenizer.
# Both names always exist after this section; None marks a failed load and is
# checked by generate_caption before it does any work.
model = None
tokenizer = None

if files_exist:
    # Tokenizer first. NOTE(review): pickle.load can execute arbitrary code
    # embedded in the file; only ship a tokenizer.pkl from a trusted source.
    try:
        with open("tokenizer.pkl", "rb") as handle:
            tokenizer = pickle.load(handle)
        print("βœ“ Tokenizer loaded successfully")
    except Exception as e:
        print(f"βœ— Failed to load tokenizer: {e}")
        tokenizer = None

    # Model second; load_model_safely raises once all of its strategies fail.
    try:
        model = load_model_safely()
        print("βœ“ Model loaded successfully and ready for inference!")
    except Exception as e:
        print(f"βœ— Failed to load model: {e}")
        print("The app will not work without a properly loaded model.")
        model = None
else:
    print("Cannot proceed without required files.")

# Image feature extractor: the stock VGG16 network truncated at its
# second-to-last layer, so predict() returns that layer's activations instead
# of the final classification output.
_vgg16 = VGG16()
feature_extractor = tf.keras.Model(
    inputs=_vgg16.input,
    outputs=_vgg16.layers[-2].output,
)
del _vgg16

# Description generation function
def _prepare_image(image):
    """Turn a PIL image into a preprocessed VGG16 batch holding one 224x224 RGB image."""
    # Uploads can be RGBA (PNG), greyscale or palette images. VGG16 expects
    # exactly three channels, so normalise the mode before anything else.
    rgb_image = image.convert("RGB").resize((224, 224))
    batch = np.expand_dims(img_to_array(rgb_image), axis=0)
    return preprocess_input(batch)


def _index_to_word(word_index):
    """Invert a ``word -> index`` mapping, keeping the first word seen for each index."""
    reverse = {}
    for word, index in word_index.items():
        reverse.setdefault(index, word)
    return reverse


def generate_caption(image):
    """Generate a caption for an uploaded image using greedy decoding.

    The VGG16 features of *image* are fed to the caption model together with
    the words produced so far. At each step the highest-scoring token is
    appended. Decoding stops at ``endseq``, at an index with no vocabulary
    entry, or after ``max_length`` steps.

    Relies on the module-level ``model``, ``tokenizer`` and
    ``feature_extractor``.

    Args:
        image: PIL image from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        Text for the Gradio output box: the caption with a success marker,
        or an error message. Errors are reported in the returned string
        rather than raised.
    """
    try:
        if model is None:
            return "❌ Model failed to load. Please check the model file and console output for details."

        if tokenizer is None:
            return "❌ Tokenizer failed to load. Please check the tokenizer.pkl file."

        if image is None:
            # Submitting without an upload would otherwise surface as an
            # opaque "'NoneType' object has no attribute ..." error.
            return "❌ Please upload an image first."

        # Preprocess the image
        batch = _prepare_image(image)

        # Extract features
        print("Extracting image features...")
        feature = feature_extractor.predict(batch, verbose=0)
        print(f"Features extracted, shape: {feature.shape}")

        # Reverse vocabulary built once per call, rather than scanning
        # word_index for every generated token.
        lookup = _index_to_word(tokenizer.word_index)

        # Generate caption
        input_text = 'startseq'
        max_length = 34  # set this to your model's max_length

        print("Starting caption generation...")
        # `step` is 1-based for the log messages and is not reused elsewhere.
        for step in range(1, max_length + 1):
            sequence = tokenizer.texts_to_sequences([input_text])[0]
            sequence = pad_sequences([sequence], maxlen=max_length)

            try:
                print(f"Prediction step {step}: input_text = '{input_text}'")
                yhat = model.predict([feature, sequence], verbose=0)
                yhat = np.argmax(yhat)
                print(f"Predicted token index: {yhat}")
            except Exception as e:
                print(f"Prediction error at step {step}: {e}")
                return f"❌ Error during prediction: {str(e)}"

            # Indices with no vocabulary entry (e.g. the 0 used for padding)
            # map to '' and end decoding.
            word = lookup.get(int(yhat), '')

            print(f"Predicted word: '{word}'")
            if word in ('endseq', ''):
                break
            input_text += ' ' + word

        caption = input_text.replace('startseq', '').strip()
        print(f"Final caption: '{caption}'")
        return f"βœ… {caption}" if caption else "❌ Unable to generate caption"

    except Exception as e:
        error_msg = f"❌ Error processing image: {str(e)}"
        print(error_msg)
        return error_msg

# Gradio Interface
title = "πŸ“Έ Image Caption Generator"
description = "Upload an image and let the AI generate a descriptive caption for it."
theme = "soft"

# One PIL image in, one caption string out. Flagging is disabled, so the UI
# shows no "Flag" button.
iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Caption"),
    allow_flagging="never",
    theme=theme,
    title=title,
    description=description,
)

# Start the web UI only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()