File size: 15,718 Bytes
2067094
 
 
 
 
7d3b243
 
2067094
7d3b243
2067094
 
7d3b243
 
 
 
 
 
 
 
 
 
2067094
7d3b243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2067094
7d3b243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2067094
7d3b243
2067094
7d3b243
2067094
 
 
 
7d3b243
2067094
 
 
 
 
7d3b243
2067094
 
 
 
 
7d3b243
 
2067094
7d3b243
2067094
 
 
 
 
7d3b243
2067094
 
7d3b243
 
 
 
 
 
 
 
 
 
2067094
 
 
 
7d3b243
 
2067094
7d3b243
 
 
 
 
 
 
 
2067094
7d3b243
 
2067094
7d3b243
2067094
 
7d3b243
2067094
 
 
7d3b243
 
 
 
2067094
7d3b243
 
2067094
 
 
7d3b243
2067094
7d3b243
 
 
 
 
 
 
 
 
 
 
 
 
2067094
7d3b243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2067094
 
 
7d3b243
2067094
7d3b243
 
2067094
 
 
 
7d3b243
2067094
7d3b243
 
 
 
2067094
7d3b243
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
import traceback
from pathlib import Path

import streamlit as st
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2
from huggingface_hub import snapshot_download

# --- Model Loading Function (COMPATIBILITY FOCUSED) ---
# Hugging Face Hub repository hosting the pretrained teeth-segmentation SavedModel.
_MODEL_REPO = "SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net"


class _TFv1ModelWrapper:
    """Give a TF1-style (session, input tensor, output tensor) triple a ``predict`` method."""

    def __init__(self, sess, input_tensor, output_tensor):
        self.sess = sess
        self.input_tensor = input_tensor
        self.output_tensor = output_tensor

    def predict(self, input_data):
        """Feed ``input_data`` through the loaded graph and return ``sess.run``'s result."""
        # Eager tensors expose .numpy(); anything else is coerced with
        # np.asarray (a no-op for values that already are ndarrays).
        if hasattr(input_data, "numpy"):
            input_data = input_data.numpy()
        input_data = np.asarray(input_data)
        return self.sess.run(self.output_tensor, feed_dict={self.input_tensor: input_data})

    def __del__(self):
        # Release the session's resources when the wrapper is garbage-collected.
        sess = getattr(self, "sess", None)
        if sess is None:
            return
        try:
            sess.close()
        except Exception:
            # Best effort: during interpreter shutdown TF may already be torn down.
            pass


def _load_with_tf_v1(model_path):
    """Load the SavedModel's ``serve`` MetaGraph into a private graph and session.

    Returns a ``_TFv1ModelWrapper``. Any loader error is re-raised after the
    session has been closed.
    """
    import tensorflow.compat.v1 as tf_v1

    # Use a dedicated graph instead of tf_v1.disable_v2_behavior(). That call
    # switches the whole process to graph mode, so if this attempt failed the
    # eager-based fallbacks (and `.numpy()` on their outputs) would break too.
    graph = tf.Graph()
    with graph.as_default():
        sess = tf_v1.Session(graph=graph)
        try:
            tf_v1.saved_model.loader.load(sess, ["serve"], model_path)
            input_tensor = graph.get_tensor_by_name("serving_default_input_1:0")
            output_tensor = graph.get_tensor_by_name("StatefulPartitionedCall:0")
        except Exception:
            # Don't leak the session when loading or tensor lookup fails.
            sess.close()
            raise
    return _TFv1ModelWrapper(sess, input_tensor, output_tensor)


class _SignatureModelWrapper:
    """Give a TF2 SavedModel signature (concrete function) a ``predict`` method."""

    def __init__(self, signature, owner=None):
        self.signature = signature
        # Keep the object returned by tf.saved_model.load alive for as long as
        # the signature is used, so the variables it captures are not released.
        self._owner = owner

    def predict(self, input_data):
        """Call the signature on ``input_data`` and return its first output tensor."""
        if hasattr(input_data, "numpy"):
            input_data = input_data.numpy()
        if not isinstance(input_data, tf.Tensor):
            input_data = tf.convert_to_tensor(input_data, dtype=tf.float32)

        # Signatures take keyword arguments; feed the data to the first declared input.
        input_specs = self.signature.structured_input_signature[1]
        input_name = next(iter(input_specs))
        result = self.signature(**{input_name: input_data})

        # Signatures return a dict of named outputs; use the first one.
        if isinstance(result, dict):
            result = next(iter(result.values()))
        return result


def _load_with_signature(model_path):
    """Load the SavedModel with the TF2 API and wrap its first signature.

    Returns ``(wrapper, signature_key)``. Raises ``Exception`` if the model
    exposes no signatures.
    """
    loaded_model = tf.saved_model.load(model_path)
    signatures = loaded_model.signatures
    st.info(f"Available signatures: {list(signatures.keys())}")
    if not signatures:
        raise Exception("No signatures found in the model")
    signature_key = list(signatures.keys())[0]
    wrapper = _SignatureModelWrapper(signatures[signature_key], owner=loaded_model)
    return wrapper, signature_key


def _create_unet_model(input_shape=(512, 512, 1)):
    """Build an untrained 4-level U-Net (64..1024 filters) with a 1-channel sigmoid output."""
    layers = tf.keras.layers

    def conv_block(x, filters):
        # Two 3x3 ReLU convolutions, as used at every U-Net level.
        x = layers.Conv2D(filters, (3, 3), activation="relu", padding="same")(x)
        return layers.Conv2D(filters, (3, 3), activation="relu", padding="same")(x)

    inputs = layers.Input(shape=input_shape)

    # Encoder: remember each level's features for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = conv_block(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck.
    x = conv_block(x, 1024)

    # Decoder: upsample and concatenate with the matching encoder level.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(x)
        x = layers.concatenate([x, skip])
        x = conv_block(x, filters)

    outputs = layers.Conv2D(1, (1, 1), activation="sigmoid")(x)
    return tf.keras.models.Model(inputs=[inputs], outputs=[outputs])


@st.cache_resource
def load_keras_model():
    """Download the pretrained U-Net and return an object with a ``predict`` method.

    Three strategies are tried in order, each reported in the Streamlit UI:

    1. TF1-style loading of the ``serve`` MetaGraph into a private graph/session.
    2. TF2 ``tf.saved_model.load`` using the model's first signature.
    3. An untrained U-Net with a (512, 512, 1) input (random weights; only
       useful for exercising the interface).

    Returns:
        The first model object that was built successfully, or ``None`` if
        all three strategies fail (the errors are shown with ``st.error``).

    Cached with ``st.cache_resource`` so Streamlit reruns reuse the loaded model.
    """
    model_path = snapshot_download(repo_id=_MODEL_REPO)
    st.info(f"Model downloaded to: {model_path}")

    errors = []

    # Approach 1: tf.compat.v1 loader (isolated graph).
    st.info("Attempt 1: Loading with tf.compat.v1...")
    try:
        model = _load_with_tf_v1(model_path)
    except Exception as exc:
        errors.append(exc)
        st.warning(f"Attempt 1 failed: {exc}")
    else:
        st.success("Model loaded successfully using tf.compat.v1!")
        return model

    # Approach 2: TF2 loader plus signature inspection.
    st.info("Attempt 2: Loading with signature inspection...")
    try:
        model, signature_key = _load_with_signature(model_path)
    except Exception as exc:
        errors.append(exc)
        st.warning(f"Attempt 2 failed: {exc}")
    else:
        st.success(f"Model loaded successfully using signature: {signature_key}!")
        return model

    # Approach 3: untrained stand-in model so the UI can still be exercised.
    st.info("Attempt 3: Creating an alternative U-Net model...")
    try:
        alt_model = _create_unet_model()
    except Exception as exc:
        errors.append(exc)
    else:
        st.warning("WARNING: Using an alternative U-Net model with random weights.")
        st.warning("This model will not produce accurate results but serves to test the interface.")
        return alt_model

    e1, e2, e3 = errors
    st.error("All loading attempts have failed.")
    st.error("Errors encountered:")
    st.error(f"1. tf.compat.v1: {e1}")
    st.error(f"2. Signature inspection: {e2}")
    st.error(f"3. Alternative model: {e3}")

    st.info("Recommended solutions:")
    st.info("1. Use an environment with TensorFlow 2.5 or compatible versions")
    st.info("2. Look for an updated version of the model")
    st.info("3. Contact the author for a version compatible with Keras 3")

    return None

# --- Helper Functions (unchanged) ---
def load_image(image_file):
    """Open an image with Pillow.

    Args:
        image_file: A filesystem path or a file-like object (for example a
            Streamlit upload) accepted by ``PIL.Image.open``.

    Returns:
        The image object produced by ``Image.open``.
    """
    opened = Image.open(image_file)
    return opened

def convert_one_channel(img_array):
    """Return a single-channel (2-D) grayscale version of ``img_array``.

    Multi-channel input is interpreted in RGB / RGBA order, which is what
    ``np.array(PIL_image)`` produces (not OpenCV's BGR order).

    - 2-D arrays are returned unchanged.
    - ``(H, W, 1)`` and ``(H, W, 2)`` (gray + alpha) arrays keep only the first
      (gray) plane.
    - ``(H, W, 3)`` and ``(H, W, 4)`` arrays are converted with OpenCV's
      luminance weights.
    """
    if img_array.ndim < 3:
        return img_array

    channels = img_array.shape[2]
    if channels <= 2:
        # Already gray (optionally with alpha); contiguous copy keeps cv2 happy.
        return np.ascontiguousarray(img_array[:, :, 0])

    # PIL arrays are RGB(A); using the BGR code would swap the R/B weights.
    code = cv2.COLOR_RGBA2GRAY if channels == 4 else cv2.COLOR_RGB2GRAY
    return cv2.cvtColor(img_array, code)

def convert_rgb(img_array):
    """Return a 3-channel RGB version of ``img_array`` for drawing contours.

    - 2-D gray arrays and ``(H, W, 1)`` / ``(H, W, 2)`` (gray + alpha) arrays
      are expanded from their gray plane to RGB.
    - ``(H, W, 4)`` RGBA arrays have their alpha channel dropped.
    - ``(H, W, 3)`` arrays (and any other channel count) are returned unchanged.
    """
    if img_array.ndim == 2:
        return cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)

    channels = img_array.shape[2]
    if channels in (1, 2):
        gray = np.ascontiguousarray(img_array[:, :, 0])
        return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    if channels == 4:
        # Drop alpha: a 3-tuple contour colour on a 4-channel image is drawn
        # with alpha 0, i.e. the contours would render transparent.
        return cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
    return img_array

# --- Streamlit App Layout ---
st.set_page_config(layout="wide")
st.header("Segmentation of Teeth in Panoramic X-rays with U-Net")

# Plain Markdown link: no raw HTML is involved, so there is no reason to opt
# into Streamlit's `unsafe_allow_html` rendering.
link = 'Check out our Repo on Github! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
st.markdown(link)

# Load the model (cached by st.cache_resource) and halt the script when every
# loading strategy failed, since nothing below can run without it.
model = load_keras_model()
if model is None:
    st.warning("The model could not be loaded. The application cannot proceed.")
    st.stop()

# --- Image Selection Section (unchanged) ---
st.subheader("Upload a Panoramic X-ray or Select an Example")

# Example images ship alongside this script. Resolve them relative to the
# script file rather than the current working directory, so they are found no
# matter where `streamlit run` is launched from.
_EXAMPLE_DIR = Path(__file__).resolve().parent
example_image_paths = {
    "Example 1": str(_EXAMPLE_DIR / "107.png"),
    "Example 2": str(_EXAMPLE_DIR / "108.png"),
    "Example 3": str(_EXAMPLE_DIR / "109.png"),
}

image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])

st.write("---")
st.write("Or choose an example:")
cols = st.columns(len(example_image_paths))
selected_example = None

for col, (caption, path) in zip(cols, example_image_paths.items()):
    with col:
        try:
            st.image(path, caption=caption, use_container_width=True)
            # st.button returns True only on the rerun triggered by the click.
            if st.button(f'Use {caption}'):
                selected_example = path
        except Exception:
            st.error(f"Example image '{path}' not found. Make sure 107.png, 108.png, and 109.png are in the same directory as the script.")

# A clicked example takes precedence over an uploaded file for this run.
if selected_example:
    image_file = selected_example

# --- Processing and Prediction Section (FIXED) ---
if image_file is not None:
    st.write("---")
    col1, col2 = st.columns(2)

    original_pil_img = load_image(image_file)
    # These PIL modes do not become the gray/RGB arrays the pipeline expects:
    # np.array() yields booleans ("1"), palette indices ("P"/"PA"), two
    # channels ("LA") or four ink channels ("CMYK"). Normalise them to RGB.
    if original_pil_img.mode in ("1", "P", "PA", "LA", "CMYK"):
        original_pil_img = original_pil_img.convert("RGB")

    with col1:
        st.image(original_pil_img, caption="Original Image", use_container_width=True)

    with st.spinner("Analyzing the image and predicting segmentation..."):
        original_np_img = np.array(original_pil_img)

        # 1. Pre-processing: grayscale, resize to 512x512, divide by 255
        #    (assumes 8-bit pixel values), reshape to a batch of one (1, 512, 512, 1).
        img_gray = convert_one_channel(original_np_img.copy())
        img_resized = cv2.resize(img_gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
        img_normalized = np.float32(img_resized / 255.0)
        img_input_np = np.reshape(img_normalized, (1, 512, 512, 1))

        # 2. Prediction. The model wrappers accept NumPy arrays directly; some
        #    return tensors, which are converted back to NumPy here.
        try:
            prediction = model.predict(img_input_np)
            if hasattr(prediction, "numpy"):
                prediction = prediction.numpy()
        except Exception as e:
            st.error(f"Prediction failed. Error: {e}")
            st.code(traceback.format_exc())
            st.stop()

        # 3. Post-processing: drop the batch axis, then keep the first channel.
        predicted_mask = prediction[0] if len(prediction.shape) == 4 else prediction
        if len(predicted_mask.shape) > 2:
            predicted_mask = predicted_mask[:, :, 0]

        # Resize the mask back to the original image size (cv2 wants (width, height)).
        resized_mask = cv2.resize(
            predicted_mask,
            (original_np_img.shape[1], original_np_img.shape[0]),
            interpolation=cv2.INTER_LANCZOS4,
        )
        # Lanczos interpolation rings past the input range. Without clipping,
        # values slightly above 1.0 or below 0.0 fall outside uint8 in the cast
        # below (NumPy's out-of-range float->int cast wraps/is undefined) and
        # show up as speckle noise in the mask.
        resized_mask = np.clip(resized_mask, 0.0, 1.0)

        # Binarize with Otsu's threshold for a clean result.
        mask_8bit = (resized_mask * 255).astype(np.uint8)
        _, final_mask = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Opening removes small specks; closing fills small holes.
        kernel = np.ones((5, 5), dtype=np.uint8)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_OPEN, kernel, iterations=2)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel, iterations=2)

        contours, _ = cv2.findContours(final_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        # Draw the contours on an RGB copy of the original; (255, 0, 0) is red in RGB.
        img_for_drawing = convert_rgb(original_np_img.copy())
        output_image = cv2.drawContours(img_for_drawing, contours, -1, (255, 0, 0), 3)

        with col2:
            st.image(output_image, caption="Image with Segmented Teeth", use_container_width=True)

        st.success("Prediction complete!")