|
import streamlit as st |
|
import tensorflow as tf |
|
from PIL import Image |
|
import numpy as np |
|
import cv2 |
|
from huggingface_hub import snapshot_download |
|
import traceback |
|
|
|
|
|
# Hugging Face Hub repository that hosts the pretrained TensorFlow SavedModel.
MODEL_REPO = "SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net"

# Tensor names used by attempt 1 when the SavedModel's "serving_default"
# signature does not describe exactly one input and one output.
_FALLBACK_INPUT_TENSOR = "serving_default_input_1:0"
_FALLBACK_OUTPUT_TENSOR = "StatefulPartitionedCall:0"


def _to_numpy(input_data):
    """Return *input_data* as a NumPy array, unwrapping eager tensors first."""
    if hasattr(input_data, "numpy"):
        input_data = input_data.numpy()
    return np.asarray(input_data)


class _TFv1ModelWrapper:
    """Give a TF1-style (graph + session) SavedModel a Keras-like ``predict``."""

    def __init__(self, sess, input_tensor, output_tensor):
        self.sess = sess
        self.input_tensor = input_tensor
        self.output_tensor = output_tensor

    def predict(self, input_data):
        """Feed *input_data* to the graph input; return the output as a NumPy array."""
        return self.sess.run(
            self.output_tensor,
            feed_dict={self.input_tensor: _to_numpy(input_data)},
        )

    def close(self):
        """Close the underlying session; calling it again is a no-op."""
        sess = getattr(self, "sess", None)
        self.sess = None
        if sess is not None:
            sess.close()

    def __del__(self):
        # Best-effort cleanup only: during interpreter shutdown TensorFlow may
        # already be partially torn down, and __del__ must never raise.
        try:
            self.close()
        except Exception:
            pass


class _SignatureModelWrapper:
    """Give a TF2 SavedModel signature (a concrete function) a Keras-like ``predict``."""

    def __init__(self, signature, owner=None):
        """Wrap *signature*.

        Args:
            signature: A callable taken from ``loaded.signatures``; it is
                called with keyword arguments.
            owner: The object returned by ``tf.saved_model.load``. It is kept
                referenced so the variables it restored stay alive for as
                long as the signature is used.

        Raises:
            ValueError: If the signature declares no keyword inputs.
        """
        self.signature = signature
        self._owner = owner
        input_names = list(signature.structured_input_signature[1])
        if not input_names:
            raise ValueError("The model signature declares no keyword inputs")
        # The original behaviour: feed the first declared keyword input.
        self._input_name = input_names[0]

    def predict(self, input_data):
        """Run the signature on *input_data* (cast to float32) and return a NumPy array."""
        array = np.asarray(_to_numpy(input_data), dtype=np.float32)
        result = self.signature(**{self._input_name: tf.convert_to_tensor(array)})
        # Signatures return a dict of named outputs; use the first one.
        if isinstance(result, dict):
            result = next(iter(result.values()))
        return result.numpy() if hasattr(result, "numpy") else result


def _serving_tensor_names(meta_graph):
    """Return ``(input_name, output_name)`` for the SavedModel's serving signature.

    Falls back to the historical hard-coded tensor names when the
    ``serving_default`` signature is missing or is not single-input/single-output.
    """
    signature_defs = meta_graph.signature_def
    if "serving_default" in signature_defs:
        signature = signature_defs["serving_default"]
        if len(signature.inputs) == 1 and len(signature.outputs) == 1:
            (input_info,) = signature.inputs.values()
            (output_info,) = signature.outputs.values()
            return input_info.name, output_info.name
    return _FALLBACK_INPUT_TENSOR, _FALLBACK_OUTPUT_TENSOR


def _load_with_tf_v1(model_path):
    """Attempt 1: load *model_path* into a private TF1 graph and wrap it.

    A dedicated ``tf.Graph`` is used instead of
    ``tf.compat.v1.disable_v2_behavior()``. A failure here then does not leave
    the whole process in graph mode, which would break the eager-mode
    attempts that follow.

    Raises:
        Exception: Whatever TensorFlow raises while loading; the session is
            closed before re-raising.
    """
    graph = tf.Graph()
    sess = tf.compat.v1.Session(graph=graph)
    try:
        with graph.as_default():
            meta_graph = tf.compat.v1.saved_model.loader.load(sess, ["serve"], model_path)
        input_name, output_name = _serving_tensor_names(meta_graph)
        input_tensor = graph.get_tensor_by_name(input_name)
        output_tensor = graph.get_tensor_by_name(output_name)
    except Exception:
        sess.close()  # do not leak the session when loading fails
        raise
    return _TFv1ModelWrapper(sess, input_tensor, output_tensor)


def _load_with_signature(model_path):
    """Attempt 2: load *model_path* with ``tf.saved_model.load`` and wrap its first signature.

    Returns:
        ``(signature_key, model)``.

    Raises:
        ValueError: If the SavedModel exposes no signatures.
    """
    loaded_model = tf.saved_model.load(model_path)
    signatures = loaded_model.signatures
    st.info(f"Available signatures: {list(signatures.keys())}")
    if not signatures:
        raise ValueError("No signatures found in the model")
    signature_key = next(iter(signatures.keys()))
    model = _SignatureModelWrapper(signatures[signature_key], owner=loaded_model)
    return signature_key, model


def _conv_block(x, filters):
    """Apply two 3x3 same-padded ReLU convolutions with *filters* channels."""
    for _ in range(2):
        x = tf.keras.layers.Conv2D(filters, (3, 3), activation="relu", padding="same")(x)
    return x


def _create_unet_model(input_shape=(512, 512, 1)):
    """Attempt 3: build an untrained U-Net (random weights) with a sigmoid 1-channel output.

    The encoder uses 64/128/256/512 filters with 2x2 max-pooling, and the
    bottleneck uses 1024 filters. The decoder upsamples with 2x2 stride-2
    transposed convolutions and concatenates the matching encoder skip.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = _conv_block(x, filters)
        skips.append(x)
        x = tf.keras.layers.MaxPooling2D((2, 2))(x)

    x = _conv_block(x, 1024)

    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = tf.keras.layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(x)
        x = tf.keras.layers.concatenate([x, skip])
        x = _conv_block(x, filters)

    outputs = tf.keras.layers.Conv2D(1, (1, 1), activation="sigmoid")(x)
    return tf.keras.models.Model(inputs=[inputs], outputs=[outputs])


@st.cache_resource
def load_keras_model():
    """Download the segmentation SavedModel and return an object with ``predict``.

    Three strategies are tried in order, each reported in the Streamlit UI:

    1. TF1-style loading into a private graph and session.
    2. ``tf.saved_model.load`` using the model's first signature.
    3. An untrained U-Net with random weights. It only exercises the UI.

    Returns:
        The first model that loads successfully, or ``None`` if every attempt
        failed. In that case each attempt's error is shown.
    """
    model_path = snapshot_download(repo_id=MODEL_REPO)
    st.info(f"Model downloaded to: {model_path}")

    # Errors are stored explicitly: names bound by `except ... as` are deleted
    # when their handler ends, so they cannot be reported later.
    errors = []

    st.info("Attempt 1: Loading with tf.compat.v1...")
    try:
        model = _load_with_tf_v1(model_path)
    except Exception as exc:
        errors.append(("tf.compat.v1", exc))
        st.warning(f"Attempt 1 failed: {exc}")
    else:
        st.success("Model loaded successfully using tf.compat.v1!")
        return model

    st.info("Attempt 2: Loading with signature inspection...")
    try:
        signature_key, model = _load_with_signature(model_path)
    except Exception as exc:
        errors.append(("Signature inspection", exc))
        st.warning(f"Attempt 2 failed: {exc}")
    else:
        st.success(f"Model loaded successfully using signature: {signature_key}!")
        return model

    st.info("Attempt 3: Creating an alternative U-Net model...")
    try:
        model = _create_unet_model()
    except Exception as exc:
        errors.append(("Alternative model", exc))
    else:
        st.warning("WARNING: Using an alternative U-Net model with random weights.")
        st.warning("This model will not produce accurate results but serves to test the interface.")
        return model

    st.error("All loading attempts have failed.")
    st.error("Errors encountered:")
    for index, (label, exc) in enumerate(errors, start=1):
        st.error(f"{index}. {label}: {exc}")

    st.info("Recommended solutions:")
    st.info("1. Use an environment with TensorFlow 2.5 or compatible versions")
    st.info("2. Look for an updated version of the model")
    st.info("3. Contact the author for a version compatible with Keras 3")

    return None
|
|
|
|
|
def load_image(image_file):
    """Open an image with Pillow and hand back the ``PIL.Image``.

    ``image_file`` may be a filesystem path (such as the bundled example
    images) or a file-like object (such as a Streamlit upload).
    """
    return Image.open(image_file)
|
|
|
def convert_one_channel(img_array):
    """Return a single-channel (grayscale) version of *img_array*.

    Multi-channel input is treated as RGB(A) channel order, because in this app
    arrays come from ``np.array(PIL.Image)``. They do not come from
    ``cv2.imread``, which would produce BGR.

    - 2-D arrays are returned unchanged (same object).
    - ``(H, W, 1)`` and ``(H, W, 2)`` (grey + alpha) arrays keep channel 0.
    - ``(H, W, 3)`` / ``(H, W, 4)`` arrays go through OpenCV's RGB(A)->gray
      conversion; alpha is not used.
    """
    if img_array.ndim < 3:
        return img_array
    channels = img_array.shape[2]
    if channels <= 2:
        # Grey plane already present; make it contiguous for later OpenCV calls.
        return np.ascontiguousarray(img_array[:, :, 0])
    code = cv2.COLOR_RGBA2GRAY if channels == 4 else cv2.COLOR_RGB2GRAY
    return cv2.cvtColor(img_array, code)
|
|
|
def convert_rgb(img_array):
    """Return a 3-channel RGB array suitable for drawing coloured contours on.

    - 2-D, ``(H, W, 1)`` and ``(H, W, 2)`` (grey + alpha) arrays have their
      grey plane replicated into R, G and B, the same result as
      ``cv2.COLOR_GRAY2RGB``.
    - ``(H, W, 3)`` arrays are returned unchanged (same object).
    - Arrays with more than three channels (e.g. RGBA) keep only the first
      three. Otherwise ``cv2.drawContours`` would also write the colour's
      alpha component (0), making the contours invisible.

    Every returned array is C-contiguous. ``cv2.drawContours`` draws in place
    and rejects non-contiguous output buffers.
    """
    if img_array.ndim == 2:
        img_array = img_array[:, :, np.newaxis]
    channels = img_array.shape[2]
    if channels == 3:
        return img_array
    if channels < 3:
        return np.repeat(img_array[:, :, :1], 3, axis=2)
    return np.ascontiguousarray(img_array[:, :, :3])
|
|
|
|
|
st.set_page_config(layout="wide")
st.header("Segmentation of Teeth in Panoramic X-rays with U-Net")

# Plain Markdown link; no raw HTML is involved, so unsafe_allow_html is not needed.
link = 'Check out our Repo on Github! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
st.markdown(link)

model = load_keras_model()
if model is None:
    st.warning("The model could not be loaded. The application cannot proceed.")
    st.stop()

st.subheader("Upload a Panoramic X-ray or Select an Example")

example_image_paths = {
    "Example 1": "107.png",
    "Example 2": "108.png",
    "Example 3": "109.png",
}

image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])

st.write("---")
st.write("Or choose an example:")
cols = st.columns(len(example_image_paths))
selected_example = None

for col, (caption, path) in zip(cols, example_image_paths.items()):
    with col:
        try:
            st.image(path, caption=caption, use_container_width=True)
            if st.button(f'Use {caption}'):
                selected_example = path
        except Exception:
            st.error(f"Example image '{path}' not found. Make sure 107.png, 108.png, and 109.png are in the same directory as the script.")

# st.button is only True on the rerun its click triggered, so a chosen example
# overrides an upload for that run.
if selected_example:
    image_file = selected_example

# Pillow modes whose np.array() form is not plain grey/RGB pixels: palette
# indices ("P"), booleans ("1"), CMYK, or an extra alpha channel. Converting
# them to RGB keeps the preprocessing and contour drawing below meaningful.
MODES_NEEDING_RGB = ("1", "P", "LA", "RGBA", "CMYK")

if image_file is not None:
    st.write("---")
    col1, col2 = st.columns(2)

    original_pil_img = load_image(image_file)
    if original_pil_img.mode in MODES_NEEDING_RGB:
        original_pil_img = original_pil_img.convert("RGB")

    with col1:
        st.image(original_pil_img, caption="Original Image", use_container_width=True)

    with st.spinner("Analyzing the image and predicting segmentation..."):
        original_np_img = np.array(original_pil_img)

        # Preprocess: grayscale -> 512x512 -> scale by 1/255 -> (batch, H, W, channel).
        img_gray = convert_one_channel(original_np_img.copy())
        img_resized = cv2.resize(img_gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
        img_normalized = np.float32(img_resized / 255.0)
        img_input_np = np.reshape(img_normalized, (1, 512, 512, 1))

        try:
            prediction = model.predict(img_input_np)
            # Some model wrappers may hand back tensors rather than NumPy arrays.
            if hasattr(prediction, 'numpy'):
                prediction = prediction.numpy()
        except Exception as e:
            st.error(f"Prediction failed. Error: {e}")
            st.code(traceback.format_exc())
            st.stop()

        prediction = np.asarray(prediction)

        # Strip the batch axis (if present), then the channel axis (if present).
        predicted_mask = prediction[0] if prediction.ndim == 4 else prediction
        if predicted_mask.ndim > 2:
            predicted_mask = predicted_mask[:, :, 0]

        # Bring the mask back to the original image size (cv2 takes (width, height)).
        resized_mask = cv2.resize(
            predicted_mask,
            (original_np_img.shape[1], original_np_img.shape[0]),
            interpolation=cv2.INTER_LANCZOS4,
        )

        # Lanczos ringing can push values slightly outside [0, 1]. Out-of-range
        # floats do not saturate when cast to uint8; they wrap or are undefined,
        # which would create spurious mask pixels. Clip first.
        resized_mask = np.clip(resized_mask, 0.0, 1.0)
        mask_8bit = (resized_mask * 255).astype(np.uint8)
        _, final_mask = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Opening removes small specks; closing fills small holes.
        kernel = np.ones((5, 5), dtype=np.uint8)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_OPEN, kernel, iterations=2)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel, iterations=2)

        contours, _ = cv2.findContours(final_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        img_for_drawing = convert_rgb(original_np_img.copy())
        output_image = cv2.drawContours(img_for_drawing, contours, -1, (255, 0, 0), 3)

    with col2:
        st.image(output_image, caption="Image with Segmented Teeth", use_container_width=True)

    st.success("Prediction complete!")