"""Streamlit app: segment teeth in panoramic dental X-rays with a U-Net.

The pretrained SavedModel is downloaded from the Hugging Face Hub and loaded
through a chain of strategies (TF1 session, TF2 signature, and finally an
untrained U-Net so the interface can still be exercised).
"""
import traceback

import cv2
import numpy as np
import streamlit as st
import tensorflow as tf
from huggingface_hub import snapshot_download
from PIL import Image

MODEL_REPO = "SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net"
# Side length, in pixels, of the square single-channel image fed to the U-Net.
MODEL_INPUT_SIZE = 512
# Tensor names used when the SavedModel carries no usable serving signature.
_LEGACY_INPUT_TENSOR = "serving_default_input_1:0"
_LEGACY_OUTPUT_TENSOR = "StatefulPartitionedCall:0"


# --- Model wrappers -----------------------------------------------------------

class _TFv1ModelWrapper:
    """Expose ``predict`` over a TF1 session that holds a loaded SavedModel graph."""

    def __init__(self, sess, input_tensor, output_tensor):
        self.sess = sess
        self.input_tensor = input_tensor
        self.output_tensor = output_tensor

    def predict(self, input_data):
        """Run the graph on an array-like (or eager tensor) batch; return a NumPy array."""
        if hasattr(input_data, "numpy"):  # eager tensor -> ndarray
            input_data = input_data.numpy()
        return self.sess.run(
            self.output_tensor,
            feed_dict={self.input_tensor: np.asarray(input_data)},
        )

    def __del__(self):
        # Best-effort cleanup: at interpreter shutdown TensorFlow may already
        # be torn down, so a failure to close is deliberately ignored here.
        sess = getattr(self, "sess", None)
        if sess is not None:
            try:
                sess.close()
            except Exception:
                pass


class _SignatureModelWrapper:
    """Expose ``predict`` over a serving signature of a TF2 SavedModel."""

    def __init__(self, loaded_model, signature_key):
        # Keep the loaded SavedModel object referenced: the signature's
        # concrete function reads variables owned by it, which may be
        # released if that object is garbage-collected.
        self.loaded_model = loaded_model
        self.signature_key = signature_key
        self.signature = loaded_model.signatures[signature_key]
        input_specs = self.signature.structured_input_signature[1]
        if not input_specs:
            # Fail at load time so the caller can fall through to the next strategy.
            raise ValueError(f"Signature '{signature_key}' declares no named inputs")
        # Signatures are invoked with keyword arguments; feed the first input.
        self.input_name = next(iter(input_specs))

    def predict(self, input_data):
        """Run the signature on a batch; return its (first) output as a NumPy array."""
        input_tensor = tf.cast(tf.convert_to_tensor(input_data), tf.float32)
        result = self.signature(**{self.input_name: input_tensor})
        if isinstance(result, dict):
            result = next(iter(result.values()))
        return result.numpy()


# --- Model loading strategies -------------------------------------------------

def _signature_tensor_names(meta_graph):
    """Return (input, output) tensor names from the ``serving_default`` signature.

    Falls back to the legacy hard-coded names when the MetaGraphDef has no
    usable ``serving_default`` signature.
    """
    signature_defs = meta_graph.signature_def
    if "serving_default" in signature_defs:
        signature = signature_defs["serving_default"]
        if signature.inputs and signature.outputs:
            input_info = next(iter(signature.inputs.values()))
            output_info = next(iter(signature.outputs.values()))
            return input_info.name, output_info.name
    return _LEGACY_INPUT_TENSOR, _LEGACY_OUTPUT_TENSOR


def _load_with_tf_v1(model_path):
    """Load the SavedModel into a private TF1 graph/session and wrap it.

    The graph is created explicitly instead of calling
    ``tf.compat.v1.disable_v2_behavior()``, which would switch the whole
    process to graph mode and break the TF2/Keras fallbacks below.
    """
    import tensorflow.compat.v1 as tf_v1

    graph = tf.Graph()
    sess = tf_v1.Session(graph=graph)  # persistent; closed by the wrapper
    try:
        with graph.as_default():
            meta_graph = tf_v1.saved_model.loader.load(sess, ["serve"], model_path)
        input_name, output_name = _signature_tensor_names(meta_graph)
        input_tensor = graph.get_tensor_by_name(input_name)
        output_tensor = graph.get_tensor_by_name(output_name)
    except Exception:
        sess.close()  # don't leak the session when this strategy fails
        raise
    return _TFv1ModelWrapper(sess, input_tensor, output_tensor)


def _load_with_signature(model_path):
    """Load the SavedModel with ``tf.saved_model.load`` and wrap a serving signature.

    Prefers ``serving_default``; otherwise uses the first available signature.
    """
    loaded_model = tf.saved_model.load(model_path)
    signatures = loaded_model.signatures
    st.info(f"Available signatures: {list(signatures.keys())}")
    if not signatures:
        raise ValueError("No signatures found in the model")
    if "serving_default" in signatures:
        signature_key = "serving_default"
    else:
        signature_key = next(iter(signatures.keys()))
    return _SignatureModelWrapper(loaded_model, signature_key)


def _create_unet_model(input_shape=(MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 1)):
    """Build an untrained U-Net (4 down/up levels, 64..1024 filters, sigmoid output)."""
    layers = tf.keras.layers

    def conv_block(x, filters):
        # Two 3x3 ReLU convolutions, the basic U-Net unit.
        x = layers.Conv2D(filters, (3, 3), activation="relu", padding="same")(x)
        return layers.Conv2D(filters, (3, 3), activation="relu", padding="same")(x)

    inputs = layers.Input(shape=input_shape)

    # Encoder: keep each level's features for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = conv_block(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = conv_block(x, 1024)

    # Decoder: upsample and concatenate with the matching encoder level.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(x)
        x = layers.concatenate([x, skip])
        x = conv_block(x, filters)

    outputs = layers.Conv2D(1, (1, 1), activation="sigmoid")(x)
    return tf.keras.models.Model(inputs=[inputs], outputs=[outputs])


@st.cache_resource
def load_keras_model():
    """Download the pretrained U-Net and load it with the first strategy that works.

    Strategies, in order:
      1. ``tf.compat.v1`` session on a private graph (legacy compatibility).
      2. ``tf.saved_model.load`` plus a serving signature.
      3. An untrained U-Net with random weights, so the UI can still be tested.

    Returns:
        An object with a ``predict(batch)`` method, or ``None`` when every
        strategy failed (the errors are reported on the Streamlit page).
    """
    model_path = snapshot_download(repo_id=MODEL_REPO)
    st.info(f"Model downloaded to: {model_path}")

    # Collected explicitly: names bound by ``except ... as`` are unbound once
    # their except clause ends, so they cannot be read in a later handler.
    errors = []

    st.info("Attempt 1: Loading with tf.compat.v1...")
    try:
        model = _load_with_tf_v1(model_path)
    except Exception as exc:
        errors.append(f"1. tf.compat.v1: {exc}")
        st.warning(f"Attempt 1 failed: {exc}")
    else:
        st.success("Model loaded successfully using tf.compat.v1!")
        return model

    st.info("Attempt 2: Loading with signature inspection...")
    try:
        model = _load_with_signature(model_path)
    except Exception as exc:
        errors.append(f"2. Signature inspection: {exc}")
        st.warning(f"Attempt 2 failed: {exc}")
    else:
        st.success(f"Model loaded successfully using signature: {model.signature_key}!")
        return model

    st.info("Attempt 3: Creating an alternative U-Net model...")
    try:
        model = _create_unet_model()
    except Exception as exc:
        errors.append(f"3. Alternative model: {exc}")
    else:
        st.warning("WARNING: Using an alternative U-Net model with random weights.")
        st.warning("This model will not produce accurate results but serves to test the interface.")
        return model

    st.error("All loading attempts have failed.")
    st.error("Errors encountered:")
    for message in errors:
        st.error(message)
    st.info("Recommended solutions:")
    st.info("1. Use an environment with TensorFlow 2.5 or compatible versions")
    st.info("2. Look for an updated version of the model")
    st.info("3. Contact the author for a version compatible with Keras 3")
    return None


# --- Image helpers ------------------------------------------------------------

def load_image(image_file):
    """Open an image from a file path or an uploaded file object.

    Palette ("P"/"PA") and bilevel ("1") images are converted to RGBA / "L"
    so that ``np.array`` yields pixel intensities instead of palette indices
    or booleans.
    """
    img = Image.open(image_file)
    if img.mode in ("P", "PA"):
        img = img.convert("RGBA")
    elif img.mode == "1":
        img = img.convert("L")
    return img


def convert_one_channel(img_array):
    """Return a 2-D grayscale version of ``img_array``.

    Arrays built from PIL images are in RGB(A) channel order, so the RGB(A)
    conversions are used. Two-channel input is treated as gray + alpha
    (PIL mode "LA"); single-channel 3-D input is reduced to 2-D.
    """
    if img_array.ndim == 2:
        return img_array
    channels = img_array.shape[2]
    if channels in (1, 2):
        return np.ascontiguousarray(img_array[:, :, 0])
    if channels == 3:
        return cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    return cv2.cvtColor(img_array, cv2.COLOR_RGBA2GRAY)


def convert_rgb(img_array):
    """Return a 3-channel RGB version of ``img_array`` for drawing colored contours."""
    if img_array.ndim == 2:
        return cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
    channels = img_array.shape[2]
    if channels in (1, 2):
        return cv2.cvtColor(np.ascontiguousarray(img_array[:, :, 0]), cv2.COLOR_GRAY2RGB)
    if channels == 4:
        # Drop alpha: drawing (255, 0, 0) on a 4-channel image writes alpha 0,
        # which would make the contours transparent.
        return cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
    return img_array


def preprocess_for_model(img_array, size=MODEL_INPUT_SIZE):
    """Turn an image array into a (1, size, size, 1) float32 batch scaled by 1/255."""
    gray = convert_one_channel(img_array)
    resized = cv2.resize(gray, (size, size), interpolation=cv2.INTER_LANCZOS4)
    return np.float32(resized / 255.0).reshape(1, size, size, 1)


def prediction_to_mask(prediction, height, width):
    """Convert raw model output into a cleaned binary uint8 mask of ``height`` x ``width``.

    Raises:
        ValueError: if the prediction cannot be reduced to a 2-D mask.
    """
    # Drop batch/singleton channel axes; for multi-channel output keep channel 0.
    mask = np.squeeze(np.asarray(prediction, dtype=np.float32))
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    if mask.ndim != 2:
        raise ValueError(f"Unexpected prediction shape: {np.shape(prediction)}")

    resized = cv2.resize(mask, (width, height), interpolation=cv2.INTER_LANCZOS4)
    # Lanczos overshoots outside [0, 1]; clip so the uint8 cast cannot wrap
    # small negatives around to ~255 (spurious foreground).
    mask_8bit = (np.clip(resized, 0.0, 1.0) * 255).astype(np.uint8)

    # Otsu's threshold, then open/close to remove speckles and fill small holes.
    _, binary = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = np.ones((5, 5), dtype=np.uint8)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
    return cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=2)


# --- Streamlit App Layout -----------------------------------------------------

st.set_page_config(layout="wide")
st.header("Segmentation of Teeth in Panoramic X-rays with U-Net")
link = 'Check out our Repo on Github! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
st.markdown(link, unsafe_allow_html=True)

# Load the model and stop the app if it fails
model = load_keras_model()
if model is None:
    st.warning("The model could not be loaded. The application cannot proceed.")
    st.stop()

# --- Image Selection Section ----------------------------------------------------

st.subheader("Upload a Panoramic X-ray or Select an Example")

# Example images are expected next to the script.
example_image_paths = {
    "Example 1": "107.png",
    "Example 2": "108.png",
    "Example 3": "109.png",
}

image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])

st.write("---")
st.write("Or choose an example:")

cols = st.columns(len(example_image_paths))
selected_example = None
for col, (caption, path) in zip(cols, example_image_paths.items()):
    with col:
        try:
            st.image(path, caption=caption, use_container_width=True)
            if st.button(f'Use {caption}'):
                selected_example = path
        except Exception:
            st.error(f"Example image '{path}' not found. Make sure 107.png, 108.png, and 109.png are in the same directory as the script.")

# A clicked example takes precedence over an uploaded file for this run.
if selected_example:
    image_file = selected_example

# --- Processing and Prediction Section ------------------------------------------

if image_file is not None:
    st.write("---")
    col1, col2 = st.columns(2)

    try:
        original_pil_img = load_image(image_file)
    except OSError as exc:  # includes PIL.UnidentifiedImageError and missing files
        st.error(f"Could not read the image. Error: {exc}")
        st.stop()

    with col1:
        st.image(original_pil_img, caption="Original Image", use_container_width=True)

    with st.spinner("Analyzing the image and predicting segmentation..."):
        original_np_img = np.array(original_pil_img)
        img_height, img_width = original_np_img.shape[:2]

        # 1. Pre-processing for the model
        img_input_np = preprocess_for_model(original_np_img)

        # 2. Run the prediction; wrappers accept the NumPy batch directly.
        try:
            prediction = model.predict(img_input_np)
            if hasattr(prediction, "numpy"):
                prediction = prediction.numpy()
            # 3. Post-processing into a clean binary mask
            final_mask = prediction_to_mask(prediction, img_height, img_width)
        except Exception as e:
            st.error(f"Prediction failed. Error: {e}")
            st.code(traceback.format_exc())
            st.stop()

        # 4. Outline the segmented regions in red on an RGB copy of the input.
        contours, _ = cv2.findContours(final_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        img_for_drawing = np.ascontiguousarray(convert_rgb(original_np_img.copy()))
        output_image = cv2.drawContours(img_for_drawing, contours, -1, (255, 0, 0), 3)

    with col2:
        st.image(output_image, caption="Image with Segmented Teeth", use_container_width=True)
    st.success("Prediction complete!")