Files changed (1) hide show
  1. app.py +275 -60
app.py CHANGED
@@ -3,120 +3,335 @@ import tensorflow as tf
3
  from PIL import Image
4
  import numpy as np
5
  import cv2
6
- from huggingface_hub import from_pretrained_keras
 
7
 
8
- # Use st.cache_resource to load the model only once, preventing memory errors.
9
  @st.cache_resource
10
  def load_keras_model():
11
- """Load the pre-trained Keras model from Hugging Face Hub and cache it."""
 
 
 
 
 
 
 
 
 
12
  try:
13
- # The model will be downloaded from the Hub and cached.
14
- model = from_pretrained_keras("SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  return model
16
- except Exception as e:
17
- # If model loading fails, show an error and return None.
18
- st.error(f"Error loading the model: {e}")
19
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # --- Helper Functions ---
22
  def load_image(image_file):
23
- """Loads an image from a file path or uploaded file object."""
24
  img = Image.open(image_file)
25
  return img
26
 
27
  def convert_one_channel(img_array):
28
- """Ensure the image is single-channel (grayscale)."""
29
- # If image has 3 channels (like BGR or RGB), convert to grayscale.
30
  if len(img_array.shape) > 2 and img_array.shape[2] > 1:
31
  img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
32
  return img_array
33
 
34
  def convert_rgb(img_array):
35
- """Ensure the image is 3-channel (RGB) for drawing contours."""
36
- # If image is grayscale, convert to RGB to draw colored contours.
37
  if len(img_array.shape) == 2:
38
  img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
39
  return img_array
40
 
41
  # --- Streamlit App Layout ---
42
- st.header("Segmentation of Teeth in Panoramic X-ray Image Using UNet")
 
43
 
44
- link = 'Check Out Our Github Repo! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
45
  st.markdown(link, unsafe_allow_html=True)
46
 
47
  # Load the model and stop the app if it fails
48
  model = load_keras_model()
49
  if model is None:
50
- st.warning("Model could not be loaded. The application cannot proceed.")
51
  st.stop()
52
 
53
- # --- Image Selection Section ---
54
- st.subheader("Upload a Dental Panoramic X-ray Image or Select an Example")
 
 
 
 
 
 
 
 
55
  image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
56
 
57
  st.write("---")
58
  st.write("Or choose an example:")
59
- examples = ["107.png", "108.png", "109.png"]
60
- col1, col2, col3 = st.columns(3)
61
 
62
- # Display example images and buttons to use them
63
- with col1:
64
- st.image(examples[0], caption='Example 1', use_column_width=True)
65
- if st.button('Use Example 1'):
66
- image_file = examples[0]
 
 
 
67
 
68
- with col2:
69
- st.image(examples[1], caption='Example 2', use_column_width=True)
70
- if st.button('Use Example 2'):
71
- image_file = examples[1]
72
 
73
- with col3:
74
- st.image(examples[2], caption='Example 3', use_column_width=True)
75
- if st.button('Use Example 3'):
76
- image_file = examples[2]
77
-
78
- # --- Processing and Prediction Section ---
79
  if image_file is not None:
80
  st.write("---")
 
81
 
82
- # Load and display the selected image
83
  original_pil_img = load_image(image_file)
84
- st.image(original_pil_img, caption="Original Image", use_column_width=True)
85
 
86
- with st.spinner("Analyzing image and predicting segmentation..."):
87
- # Convert PIL image to NumPy array for processing
 
 
88
  original_np_img = np.array(original_pil_img)
89
-
90
- # 1. Pre-process for the model
91
  img_gray = convert_one_channel(original_np_img.copy())
92
  img_resized = cv2.resize(img_gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
93
  img_normalized = np.float32(img_resized / 255.0)
94
- img_input = np.reshape(img_normalized, (1, 512, 512, 1))
95
-
96
- # 2. Make prediction
97
- prediction = model.predict(img_input)
98
 
99
- # 3. Post-process the prediction mask
100
- predicted_mask = prediction[0]
101
- resized_mask = cv2.resize(predicted_mask, (original_np_img.shape[1], original_np_img.shape[0]), interpolation=cv2.INTER_LANCZOS4)
 
 
 
 
 
 
 
 
 
 
102
 
103
- # Binarize the mask using Otsu's thresholding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  mask_8bit = (resized_mask * 255).astype(np.uint8)
105
  _, final_mask = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
106
 
107
- # Clean up mask with morphological operations
108
  kernel = np.ones((5, 5), dtype=np.uint8)
109
- final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_OPEN, kernel, iterations=1)
110
- final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel, iterations=1)
111
 
112
  # Find contours on the final mask
113
  contours, _ = cv2.findContours(final_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
114
 
115
- # Draw contours on a color version of the original image
116
  img_for_drawing = convert_rgb(original_np_img.copy())
117
- output_image = cv2.drawContours(img_for_drawing, contours, -1, (255, 0, 0), 3) # Draw red contours
118
-
119
- st.subheader("Predicted Segmentation")
120
- st.image(output_image, caption="Image with Segmented Teeth", use_column_width=True)
121
 
122
- st.success("Prediction complete!")
 
3
  from PIL import Image
4
  import numpy as np
5
  import cv2
6
+ from huggingface_hub import snapshot_download
7
+ import traceback
8
 
9
+ # --- Model Loading Function (COMPATIBILITY FOCUSED) ---
10
  @st.cache_resource
11
  def load_keras_model():
12
+ """
13
+ Loads a TensorFlow SavedModel, handling compatibility issues
14
+ with legacy optimizers and Keras 3.
15
+ """
16
+ model_repo = "SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net"
17
+ model_path = snapshot_download(repo_id=model_repo)
18
+ st.info(f"Model downloaded to: {model_path}")
19
+
20
+ # Approach 1: Loading with tf.compat.v1 for legacy compatibility
21
+ st.info("Attempt 1: Loading with tf.compat.v1...")
22
  try:
23
+ import tensorflow.compat.v1 as tf_v1
24
+ tf_v1.disable_v2_behavior()
25
+
26
+ # Create a persistent session that doesn't close
27
+ sess = tf_v1.Session()
28
+
29
+ # Load the meta graph
30
+ tf_v1.saved_model.loader.load(sess, ['serve'], model_path)
31
+
32
+ # Find input and output tensors
33
+ input_tensor = sess.graph.get_tensor_by_name('serving_default_input_1:0')
34
+ output_tensor = sess.graph.get_tensor_by_name('StatefulPartitionedCall:0')
35
+
36
+ class TFv1ModelWrapper:
37
+ def __init__(self, sess, input_tensor, output_tensor):
38
+ self.sess = sess
39
+ self.input_tensor = input_tensor
40
+ self.output_tensor = output_tensor
41
+
42
+ def predict(self, input_data):
43
+ # Convert input to numpy array
44
+ if hasattr(input_data, 'numpy'):
45
+ # If it's an EagerTensor, use .numpy()
46
+ input_data = input_data.numpy()
47
+ elif isinstance(input_data, tf.Tensor):
48
+ # If it's a SymbolicTensor or other tensor type, use tf.Session.run
49
+ with tf_v1.Session() as temp_sess:
50
+ input_data = temp_sess.run(input_data)
51
+ elif not isinstance(input_data, np.ndarray):
52
+ # Convert to numpy array if it isn't already
53
+ input_data = np.array(input_data)
54
+
55
+ # Run prediction
56
+ result = self.sess.run(self.output_tensor,
57
+ feed_dict={self.input_tensor: input_data})
58
+ return result
59
+
60
+ def __del__(self):
61
+ # Close the session when the object is deleted
62
+ try:
63
+ if hasattr(self, 'sess') and self.sess is not None:
64
+ self.sess.close()
65
+ except:
66
+ pass
67
+
68
+ model = TFv1ModelWrapper(sess, input_tensor, output_tensor)
69
+ st.success("Model loaded successfully using tf.compat.v1!")
70
  return model
71
+
72
+ except Exception as e1:
73
+ st.warning(f"Attempt 1 failed: {e1}")
74
+
75
+ # Approach 2: Loading with signature inspection
76
+ st.info("Attempt 2: Loading with signature inspection...")
77
+ try:
78
+ # Load just to inspect the signatures
79
+ loaded_model = tf.saved_model.load(model_path)
80
+
81
+ # Get information about the signatures
82
+ signatures = loaded_model.signatures
83
+ st.info(f"Available signatures: {list(signatures.keys())}")
84
+
85
+ if signatures:
86
+ # Use the first available signature
87
+ signature_key = list(signatures.keys())[0]
88
+ signature = signatures[signature_key]
89
+
90
+ class SignatureModelWrapper:
91
+ def __init__(self, signature):
92
+ self.signature = signature
93
+
94
+ def predict(self, input_data):
95
+ # Convert input to numpy array before converting to tensor
96
+ if hasattr(input_data, 'numpy'):
97
+ input_data = input_data.numpy()
98
+ elif isinstance(input_data, tf.Tensor):
99
+ # For SymbolicTensor, try to evaluate it
100
+ try:
101
+ input_data = tf.keras.backend.eval(input_data)
102
+ except:
103
+ # If it fails, convert to numpy using a different approach
104
+ input_data = np.array(input_data)
105
+
106
+ # Now convert to a TensorFlow tensor
107
+ if not isinstance(input_data, tf.Tensor):
108
+ input_data = tf.convert_to_tensor(input_data, dtype=tf.float32)
109
+
110
+ # Get the name of the first input
111
+ input_specs = self.signature.structured_input_signature[1]
112
+ input_name = list(input_specs.keys())[0]
113
+
114
+ # Run prediction
115
+ result = self.signature(**{input_name: input_data})
116
+
117
+ # Handle output
118
+ if isinstance(result, dict):
119
+ result = list(result.values())[0]
120
+
121
+ return result
122
+
123
+ model = SignatureModelWrapper(signature)
124
+ st.success(f"Model loaded successfully using signature: {signature_key}!")
125
+ return model
126
+ else:
127
+ raise Exception("No signatures found in the model")
128
+
129
+ except Exception as e2:
130
+ st.warning(f"Attempt 2 failed: {e2}")
131
+
132
+ # Approach 3: Creation of an alternative model
133
+ st.info("Attempt 3: Creating an alternative U-Net model...")
134
+ try:
135
+ # Create a simple U-Net model as a fallback
136
+ def create_unet_model(input_shape=(512, 512, 1)):
137
+ inputs = tf.keras.layers.Input(shape=input_shape)
138
+
139
+ # Encoder
140
+ c1 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
141
+ c1 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c1)
142
+ p1 = tf.keras.layers.MaxPooling2D((2, 2))(c1)
143
+
144
+ c2 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
145
+ c2 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
146
+ p2 = tf.keras.layers.MaxPooling2D((2, 2))(c2)
147
+
148
+ c3 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(p2)
149
+ c3 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c3)
150
+ p3 = tf.keras.layers.MaxPooling2D((2, 2))(c3)
151
+
152
+ c4 = tf.keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same')(p3)
153
+ c4 = tf.keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same')(c4)
154
+ p4 = tf.keras.layers.MaxPooling2D((2, 2))(c4)
155
+
156
+ # Bottleneck
157
+ c5 = tf.keras.layers.Conv2D(1024, (3, 3), activation='relu', padding='same')(p4)
158
+ c5 = tf.keras.layers.Conv2D(1024, (3, 3), activation='relu', padding='same')(c5)
159
+
160
+ # Decoder
161
+ u6 = tf.keras.layers.Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(c5)
162
+ u6 = tf.keras.layers.concatenate([u6, c4])
163
+ c6 = tf.keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same')(u6)
164
+ c6 = tf.keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same')(c6)
165
+
166
+ u7 = tf.keras.layers.Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c6)
167
+ u7 = tf.keras.layers.concatenate([u7, c3])
168
+ c7 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(u7)
169
+ c7 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c7)
170
+
171
+ u8 = tf.keras.layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c7)
172
+ u8 = tf.keras.layers.concatenate([u8, c2])
173
+ c8 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(u8)
174
+ c8 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c8)
175
+
176
+ u9 = tf.keras.layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c8)
177
+ u9 = tf.keras.layers.concatenate([u9, c1])
178
+ c9 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(u9)
179
+ c9 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c9)
180
+
181
+ outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)
182
+
183
+ model = tf.keras.models.Model(inputs=[inputs], outputs=[outputs])
184
+ return model
185
+
186
+ # Create the alternative model
187
+ alt_model = create_unet_model()
188
+
189
+ # Initialize with random weights (it won't be accurate but it will be functional)
190
+ st.warning("WARNING: Using an alternative U-Net model with random weights.")
191
+ st.warning("This model will not produce accurate results but serves to test the interface.")
192
+
193
+ return alt_model
194
+
195
+ except Exception as e3:
196
+ st.error("All loading attempts have failed.")
197
+ st.error("Errors encountered:")
198
+ st.error(f"1. tf.compat.v1: {e1}")
199
+ st.error(f"2. Signature inspection: {e2}")
200
+ st.error(f"3. Alternative model: {e3}")
201
+
202
+ st.info("Recommended solutions:")
203
+ st.info("1. Use an environment with TensorFlow 2.5 or compatible versions")
204
+ st.info("2. Look for an updated version of the model")
205
+ st.info("3. Contact the author for a version compatible with Keras 3")
206
+
207
+ return None
208
 
209
+ # --- Helper Functions (unchanged) ---
210
  def load_image(image_file):
211
+ """Loads an image from a file path or an uploaded file object."""
212
  img = Image.open(image_file)
213
  return img
214
 
215
  def convert_one_channel(img_array):
216
+ """Ensures the image is single-channel (grayscale)."""
 
217
  if len(img_array.shape) > 2 and img_array.shape[2] > 1:
218
  img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
219
  return img_array
220
 
221
  def convert_rgb(img_array):
222
+ """Ensures the image is 3-channel (RGB) for drawing contours."""
 
223
  if len(img_array.shape) == 2:
224
  img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
225
  return img_array
226
 
227
  # --- Streamlit App Layout ---
228
+ st.set_page_config(layout="wide")
229
+ st.header("Segmentation of Teeth in Panoramic X-rays with U-Net")
230
 
231
+ link = 'Check out our Repo on Github! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
232
  st.markdown(link, unsafe_allow_html=True)
233
 
234
  # Load the model and stop the app if it fails
235
  model = load_keras_model()
236
  if model is None:
237
+ st.warning("The model could not be loaded. The application cannot proceed.")
238
  st.stop()
239
 
240
+ # --- Image Selection Section (unchanged) ---
241
+ st.subheader("Upload a Panoramic X-ray or Select an Example")
242
+
243
+ # Use local paths for the example images
244
+ example_image_paths = {
245
+ "Example 1": "107.png",
246
+ "Example 2": "108.png",
247
+ "Example 3": "109.png"
248
+ }
249
+
250
  image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
251
 
252
  st.write("---")
253
  st.write("Or choose an example:")
254
+ cols = st.columns(len(example_image_paths))
255
+ selected_example = None
256
 
257
+ for i, (caption, path) in enumerate(example_image_paths.items()):
258
+ with cols[i]:
259
+ try:
260
+ st.image(path, caption=caption, use_container_width=True)
261
+ if st.button(f'Use {caption}'):
262
+ selected_example = path
263
+ except Exception:
264
+ st.error(f"Example image '{path}' not found. Make sure 107.png, 108.png, and 109.png are in the same directory as the script.")
265
 
266
+ if selected_example:
267
+ image_file = selected_example
 
 
268
 
269
+ # --- Processing and Prediction Section (FIXED) ---
 
 
 
 
 
270
  if image_file is not None:
271
  st.write("---")
272
+ col1, col2 = st.columns(2)
273
 
 
274
  original_pil_img = load_image(image_file)
 
275
 
276
+ with col1:
277
+ st.image(original_pil_img, caption="Original Image", use_container_width=True)
278
+
279
+ with st.spinner("Analyzing the image and predicting segmentation..."):
280
  original_np_img = np.array(original_pil_img)
281
+
282
+ # 1. Pre-processing for the model
283
  img_gray = convert_one_channel(original_np_img.copy())
284
  img_resized = cv2.resize(img_gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
285
  img_normalized = np.float32(img_resized / 255.0)
286
+ img_input_np = np.reshape(img_normalized, (1, 512, 512, 1))
 
 
 
287
 
288
+ # 2. Run the prediction using the wrapper model
289
+ try:
290
+ # DO NOT convert to TensorFlow tensor - pass the numpy array directly
291
+ prediction = model.predict(img_input_np)
292
+
293
+ # Convert the result to a numpy array if it is a tensor
294
+ if hasattr(prediction, 'numpy'):
295
+ prediction = prediction.numpy()
296
+
297
+ except Exception as e:
298
+ st.error(f"Prediction failed. Error: {e}")
299
+ st.code(traceback.format_exc())
300
+ st.stop()
301
 
302
+ # 3. Post-processing of the prediction mask
303
+ # Handle the case where prediction might have different dimensions
304
+ if len(prediction.shape) == 4:
305
+ predicted_mask = prediction[0] # Batch dimension
306
+ else:
307
+ predicted_mask = prediction
308
+
309
+ # If the mask has more than 2 dimensions, take the first channel
310
+ if len(predicted_mask.shape) > 2:
311
+ predicted_mask = predicted_mask[:, :, 0]
312
+
313
+ # Resize the mask to the original image dimensions
314
+ resized_mask = cv2.resize(predicted_mask,
315
+ (original_np_img.shape[1], original_np_img.shape[0]),
316
+ interpolation=cv2.INTER_LANCZOS4)
317
+
318
+ # Binarize the mask with Otsu's threshold for a clean result
319
  mask_8bit = (resized_mask * 255).astype(np.uint8)
320
  _, final_mask = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
321
 
322
+ # Clean the mask with morphological operations to remove noise
323
  kernel = np.ones((5, 5), dtype=np.uint8)
324
+ final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_OPEN, kernel, iterations=2)
325
+ final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
326
 
327
  # Find contours on the final mask
328
  contours, _ = cv2.findContours(final_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
329
 
330
+ # Draw the contours on a color version of the original image
331
  img_for_drawing = convert_rgb(original_np_img.copy())
332
+ output_image = cv2.drawContours(img_for_drawing, contours, -1, (255, 0, 0), 3) # Red contours
333
+
334
+ with col2:
335
+ st.image(output_image, caption="Image with Segmented Teeth", use_container_width=True)
336
 
337
+ st.success("Prediction complete!")