pr0ximaCent commited on
Commit
7034074
·
verified ·
1 Parent(s): 4145d27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -32
app.py CHANGED
@@ -7,50 +7,113 @@ from tensorflow.keras.preprocessing.image import img_to_array
7
  from tensorflow.keras.preprocessing.sequence import pad_sequences
8
  import pickle
9
 
10
- # Custom Lambda layer with explicit output shape
11
- class CustomLambda(tf.keras.layers.Lambda):
12
- def __init__(self, function, output_shape=None, **kwargs):
13
- super().__init__(function, output_shape=output_shape, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def compute_output_shape(self, input_shape):
16
- if self.output_shape is None:
17
- # Default behavior for attention-like operations
18
- if isinstance(input_shape, list) and len(input_shape) == 2:
19
- return input_shape[0] # Return shape of first input
20
- return input_shape
21
- return super().compute_output_shape(input_shape)
22
 
23
- # Define custom objects for model loading
24
- custom_objects = {
25
- 'Lambda': CustomLambda,
26
- 'lambda': CustomLambda
27
- }
 
 
 
 
28
 
29
- # Multiple loading strategies
30
  def load_model_safely():
31
- strategies = [
32
- # Strategy 1: Load with custom objects
33
- lambda: tf.keras.models.load_model("caption_model.h5", custom_objects=custom_objects),
34
- # Strategy 2: Load without compilation
35
- lambda: tf.keras.models.load_model("caption_model.h5", compile=False),
36
- # Strategy 3: Load with different custom objects
37
- lambda: tf.keras.models.load_model("caption_model.h5",
38
- custom_objects={'Lambda': tf.keras.layers.Lambda}),
39
  ]
40
 
41
- for i, strategy in enumerate(strategies, 1):
42
  try:
43
- model = strategy()
44
- print(f"Model loaded successfully using strategy {i}!")
 
 
45
  return model
46
  except Exception as e:
47
- print(f"Strategy {i} failed: {e}")
48
  continue
49
 
50
- raise Exception("All loading strategies failed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # Load your pre-trained model and tokenizer
53
- model = load_model_safely()
 
 
 
 
 
 
54
 
55
  with open("tokenizer.pkl", "rb") as handle:
56
  tokenizer = pickle.load(handle)
@@ -62,6 +125,9 @@ feature_extractor = tf.keras.Model(feature_extractor.input, feature_extractor.la
62
  # Description generation function
63
  def generate_caption(image):
64
  try:
 
 
 
65
  # Preprocess the image
66
  image = image.resize((224, 224))
67
  image = img_to_array(image)
@@ -84,7 +150,7 @@ def generate_caption(image):
84
  yhat = np.argmax(yhat)
85
  except Exception as e:
86
  print(f"Prediction error: {e}")
87
- return "Error generating caption"
88
 
89
  word = ''
90
  for w, i in tokenizer.word_index.items():
@@ -97,7 +163,7 @@ def generate_caption(image):
97
  input_text += ' ' + word
98
 
99
  caption = input_text.replace('startseq', '').strip()
100
- return caption
101
 
102
  except Exception as e:
103
  return f"Error processing image: {str(e)}"
 
7
  from tensorflow.keras.preprocessing.sequence import pad_sequences
8
  import pickle
9
 
10
+ # Custom function to handle attention mechanism
11
+ def attention_function(inputs):
12
+ """
13
+ Custom attention function that likely combines two inputs
14
+ Input 1: (None, 34, 34) - attention weights
15
+ Input 2: (None, 34, 512) - feature vectors
16
+ Output: (None, 34, 512) - attended features
17
+ """
18
+ attention_weights, features = inputs
19
+ # Expand attention weights to match feature dimensions
20
+ attention_weights = tf.expand_dims(attention_weights, axis=-1)
21
+ # Apply attention weights to features
22
+ attended_features = attention_weights * features
23
+ return attended_features
24
+
25
+ def attention_output_shape(input_shapes):
26
+ """Define the output shape for attention mechanism"""
27
+ # Return the shape of the feature input (second input)
28
+ return input_shapes[1] # (None, 34, 512)
29
+
30
+ # Alternative attention functions to try
31
+ def attention_function_v2(inputs):
32
+ """Alternative attention mechanism - weighted sum"""
33
+ attention_weights, features = inputs
34
+ # Normalize attention weights
35
+ attention_weights = tf.nn.softmax(attention_weights, axis=-1)
36
+ attention_weights = tf.expand_dims(attention_weights, axis=-1)
37
+ return attention_weights * features
38
+
39
+ def attention_function_v3(inputs):
40
+ """Another alternative - dot product attention"""
41
+ attention_weights, features = inputs
42
+ # Sum along the second dimension of attention weights
43
+ attention_weights = tf.reduce_sum(attention_weights, axis=-1, keepdims=True)
44
+ attention_weights = tf.expand_dims(attention_weights, axis=-1)
45
+ return attention_weights * features
46
+
47
+ # Custom Lambda layer class
48
+ class AttentionLambda(tf.keras.layers.Lambda):
49
+ def __init__(self, function, output_shape_func=None, **kwargs):
50
+ super().__init__(function, **kwargs)
51
+ self.output_shape_func = output_shape_func
52
 
53
  def compute_output_shape(self, input_shape):
54
+ if self.output_shape_func:
55
+ return self.output_shape_func(input_shape)
56
+ # Default: return the shape of the second input (features)
57
+ if isinstance(input_shape, list) and len(input_shape) >= 2:
58
+ return input_shape[1]
59
+ return input_shape
60
 
61
+ # Define multiple custom objects to try different attention mechanisms
62
+ def get_custom_objects(attention_func, output_shape_func):
63
+ return {
64
+ 'Lambda': lambda function=None, **kwargs: AttentionLambda(
65
+ attention_func if function is None else function,
66
+ output_shape_func,
67
+ **kwargs
68
+ )
69
+ }
70
 
71
+ # Multiple loading strategies with different attention mechanisms
72
  def load_model_safely():
73
+ attention_strategies = [
74
+ (attention_function, attention_output_shape),
75
+ (attention_function_v2, attention_output_shape),
76
+ (attention_function_v3, attention_output_shape),
 
 
 
 
77
  ]
78
 
79
+ for i, (att_func, shape_func) in enumerate(attention_strategies, 1):
80
  try:
81
+ print(f"Trying attention strategy {i}...")
82
+ custom_objects = get_custom_objects(att_func, shape_func)
83
+ model = tf.keras.models.load_model("caption_model.h5", custom_objects=custom_objects)
84
+ print(f"Model loaded successfully using attention strategy {i}!")
85
  return model
86
  except Exception as e:
87
+ print(f"Attention strategy {i} failed: {e}")
88
  continue
89
 
90
+ # If all attention strategies fail, try loading without compilation
91
+ try:
92
+ print("Trying to load without compilation...")
93
+ model = tf.keras.models.load_model("caption_model.h5", compile=False)
94
+ print("Model loaded without compilation!")
95
+ return model
96
+ except Exception as e:
97
+ print(f"Loading without compilation failed: {e}")
98
+
99
+ # Last resort: try to load and rebuild the model
100
+ try:
101
+ print("Attempting to load model weights only...")
102
+ # This is a more complex approach that would require knowing the model architecture
103
+ raise Exception("Model architecture reconstruction needed")
104
+ except:
105
+ pass
106
+
107
+ raise Exception("All loading strategies failed. The model may need to be retrained or converted.")
108
 
109
  # Load your pre-trained model and tokenizer
110
+ try:
111
+ model = load_model_safely()
112
+ except Exception as e:
113
+ print(f"Failed to load model: {e}")
114
+ print("Creating a dummy model for testing...")
115
+ # Create a simple dummy model for testing the interface
116
+ model = None
117
 
118
  with open("tokenizer.pkl", "rb") as handle:
119
  tokenizer = pickle.load(handle)
 
125
  # Description generation function
126
  def generate_caption(image):
127
  try:
128
+ if model is None:
129
+ return "Model failed to load. Please check the model file."
130
+
131
  # Preprocess the image
132
  image = image.resize((224, 224))
133
  image = img_to_array(image)
 
150
  yhat = np.argmax(yhat)
151
  except Exception as e:
152
  print(f"Prediction error: {e}")
153
+ return f"Error during prediction: {str(e)}"
154
 
155
  word = ''
156
  for w, i in tokenizer.word_index.items():
 
163
  input_text += ' ' + word
164
 
165
  caption = input_text.replace('startseq', '').strip()
166
+ return caption if caption else "Unable to generate caption"
167
 
168
  except Exception as e:
169
  return f"Error processing image: {str(e)}"