import streamlit as st
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2
from huggingface_hub import snapshot_download
import traceback
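
# Note: the loading fallbacks below assume a TensorFlow 2.x runtime where the
# tf.compat.v1 shim is still available; under a pure Keras 3 setup the first
# two attempts may fail, leaving the random-weight fallback model.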
# --- Model Loading Function (COMPATIBILITY FOCUSED) ---
@st.cache_resource
def load_keras_model():
    """
    Loads a TensorFlow SavedModel, handling compatibility issues
    with legacy optimizers and Keras 3.
    """
    model_repo = "SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net"
    model_path = snapshot_download(repo_id=model_repo)
    st.info(f"Model downloaded to: {model_path}")
    # Approach 1: Loading with tf.compat.v1 for legacy compatibility
    st.info("Attempt 1: Loading with tf.compat.v1...")
    try:
        import tensorflow.compat.v1 as tf_v1
        tf_v1.disable_v2_behavior()
        # Create a persistent session that stays open for later predictions
        sess = tf_v1.Session()
        # Load the meta graph
        tf_v1.saved_model.loader.load(sess, ['serve'], model_path)
        # Find the input and output tensors by name
        input_tensor = sess.graph.get_tensor_by_name('serving_default_input_1:0')
        output_tensor = sess.graph.get_tensor_by_name('StatefulPartitionedCall:0')
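        # Note: these two tensor names are specific to this SavedModel export;
        # if loading fails here, the actual names can be inspected offline with
        # TensorFlow's saved_model_cli, e.g.:
        #   saved_model_cli show --dir <model_path> --tag_set serve --signature_def serving_default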
        class TFv1ModelWrapper:
            def __init__(self, sess, input_tensor, output_tensor):
                self.sess = sess
                self.input_tensor = input_tensor
                self.output_tensor = output_tensor

            def predict(self, input_data):
                # Convert the input to a numpy array
                if hasattr(input_data, 'numpy'):
                    # If it's an EagerTensor, use .numpy()
                    input_data = input_data.numpy()
                elif isinstance(input_data, tf.Tensor):
                    # If it's a SymbolicTensor or another tensor type,
                    # evaluate it with a session run
                    with tf_v1.Session() as temp_sess:
                        input_data = temp_sess.run(input_data)
                elif not isinstance(input_data, np.ndarray):
                    # Convert to a numpy array if it isn't one already
                    input_data = np.array(input_data)
                # Run the prediction
                result = self.sess.run(self.output_tensor,
                                       feed_dict={self.input_tensor: input_data})
                return result
            def __del__(self):
                # Close the session when the wrapper is garbage-collected
                try:
                    if hasattr(self, 'sess') and self.sess is not None:
                        self.sess.close()
                except Exception:
                    pass
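            # (Cleanup here is best-effort: because the wrapper is cached via
            # st.cache_resource, __del__ typically runs only when the cache
            # entry is evicted or the process exits.)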
        model = TFv1ModelWrapper(sess, input_tensor, output_tensor)
        st.success("Model loaded successfully using tf.compat.v1!")
        return model
    except Exception as e1:
        error_1 = e1  # keep a reference: Python unbinds `e1` after this block
        st.warning(f"Attempt 1 failed: {error_1}")
    # Approach 2: Loading with signature inspection
    st.info("Attempt 2: Loading with signature inspection...")
    try:
        # Load the model just to inspect its signatures
        loaded_model = tf.saved_model.load(model_path)
        # Get information about the available signatures
        signatures = loaded_model.signatures
        st.info(f"Available signatures: {list(signatures.keys())}")
        if signatures:
            # Use the first available signature
            signature_key = list(signatures.keys())[0]
            signature = signatures[signature_key]
            class SignatureModelWrapper:
                def __init__(self, signature):
                    self.signature = signature

                def predict(self, input_data):
                    # Convert the input to a numpy array before building a tensor
                    if hasattr(input_data, 'numpy'):
                        input_data = input_data.numpy()
                    elif isinstance(input_data, tf.Tensor):
                        # For symbolic tensors, try to evaluate them
                        try:
                            input_data = tf.keras.backend.eval(input_data)
                        except Exception:
                            # If that fails, fall back to a plain numpy conversion
                            input_data = np.array(input_data)
                    # Now convert to a TensorFlow tensor
                    if not isinstance(input_data, tf.Tensor):
                        input_data = tf.convert_to_tensor(input_data, dtype=tf.float32)
                    # Get the name of the first input: structured_input_signature
                    # is an (args, kwargs) pair, and index [1] is the kwargs dict
                    input_specs = self.signature.structured_input_signature[1]
                    input_name = list(input_specs.keys())[0]
                    # Run the prediction
                    result = self.signature(**{input_name: input_data})
                    # Signature calls return a dict of outputs; unwrap the first one
                    if isinstance(result, dict):
                        result = list(result.values())[0]
                    return result

            model = SignatureModelWrapper(signature)
            st.success(f"Model loaded successfully using signature: {signature_key}!")
            return model
        else:
            raise ValueError("No signatures found in the model")
    except Exception as e2:
        error_2 = e2  # see the note above about exception-variable scoping
        st.warning(f"Attempt 2 failed: {error_2}")
    # Approach 3: Creating an alternative model
    st.info("Attempt 3: Creating an alternative U-Net model...")
    try:
        # Create a simple U-Net model as a fallback
        def create_unet_model(input_shape=(512, 512, 1)):
            inputs = tf.keras.layers.Input(shape=input_shape)
            # Encoder
            c1 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
            c1 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c1)
            p1 = tf.keras.layers.MaxPooling2D((2, 2))(c1)
            c2 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
            c2 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
            p2 = tf.keras.layers.MaxPooling2D((2, 2))(c2)
            c3 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(p2)
            c3 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c3)
            p3 = tf.keras.layers.MaxPooling2D((2, 2))(c3)
            c4 = tf.keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same')(p3)
            c4 = tf.keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same')(c4)
            p4 = tf.keras.layers.MaxPooling2D((2, 2))(c4)
            # Bottleneck
            c5 = tf.keras.layers.Conv2D(1024, (3, 3), activation='relu', padding='same')(p4)
            c5 = tf.keras.layers.Conv2D(1024, (3, 3), activation='relu', padding='same')(c5)
            # Decoder
            u6 = tf.keras.layers.Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(c5)
            u6 = tf.keras.layers.concatenate([u6, c4])
            c6 = tf.keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same')(u6)
            c6 = tf.keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same')(c6)
            u7 = tf.keras.layers.Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c6)
            u7 = tf.keras.layers.concatenate([u7, c3])
            c7 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(u7)
            c7 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c7)
            u8 = tf.keras.layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c7)
            u8 = tf.keras.layers.concatenate([u8, c2])
            c8 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(u8)
            c8 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c8)
            u9 = tf.keras.layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c8)
            u9 = tf.keras.layers.concatenate([u9, c1])
            c9 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(u9)
            c9 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c9)
            outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)
            model = tf.keras.models.Model(inputs=[inputs], outputs=[outputs])
            return model
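        # The layout above is a standard U-Net encoder/decoder with skip
        # connections (filter widths 64 -> 1024 and back); the exact widths are
        # an assumption, since the original model's configuration is unavailable.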
        # Build the fallback model; nothing loads pretrained weights, so it
        # keeps Keras's default random initialization
        alt_model = create_unet_model()
        st.warning("WARNING: Using an alternative U-Net model with random weights.")
        st.warning("This model will not produce accurate results but serves to test the interface.")
        return alt_model
    except Exception as e3:
        st.error("All loading attempts have failed.")
        st.error("Errors encountered:")
        st.error(f"1. tf.compat.v1: {error_1}")
        st.error(f"2. Signature inspection: {error_2}")
        st.error(f"3. Alternative model: {e3}")
        st.info("Recommended solutions:")
        st.info("1. Use an environment with TensorFlow 2.5 or a compatible version")
        st.info("2. Look for an updated version of the model")
        st.info("3. Contact the author for a version compatible with Keras 3")
        return None
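
# Whichever attempt succeeds, the returned object exposes a Keras-style
# predict(numpy_array) method, so the UI code below stays loader-agnostic.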
# --- Helper Functions (unchanged) ---
def load_image(image_file):
    """Loads an image from a file path or an uploaded file object."""
    img = Image.open(image_file)
    return img
def convert_one_channel(img_array):
    """Ensures the image is single-channel (grayscale)."""
    if len(img_array.shape) > 2 and img_array.shape[2] > 1:
        # PIL yields RGB(A) arrays, so use the RGB conversion codes
        if img_array.shape[2] == 4:
            img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2GRAY)
        else:
            img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    return img_array
def convert_rgb(img_array):
    """Ensures the image is 3-channel (RGB) for drawing contours."""
    if len(img_array.shape) == 2:
        img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
    return img_array
# --- Streamlit App Layout ---
st.set_page_config(layout="wide")
st.header("Segmentation of Teeth in Panoramic X-rays with U-Net")
link = 'Check out our repo on GitHub! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
st.markdown(link, unsafe_allow_html=True)
# Load the model and stop the app if it fails
model = load_keras_model()
if model is None:
    st.warning("The model could not be loaded. The application cannot proceed.")
    st.stop()
# --- Image Selection Section (unchanged) ---
st.subheader("Upload a Panoramic X-ray or Select an Example")
# Use local paths for the example images
example_image_paths = {
    "Example 1": "107.png",
    "Example 2": "108.png",
    "Example 3": "109.png"
}
image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
st.write("---")
st.write("Or choose an example:")
cols = st.columns(len(example_image_paths))
selected_example = None
for i, (caption, path) in enumerate(example_image_paths.items()):
    with cols[i]:
        try:
            st.image(path, caption=caption, use_container_width=True)
            if st.button(f'Use {caption}'):
                selected_example = path
        except Exception:
            st.error(f"Example image '{path}' not found. Make sure 107.png, 108.png, and 109.png are in the same directory as the script.")
if selected_example:
    image_file = selected_example
# --- Processing and Prediction Section (FIXED) ---
if image_file is not None:
    st.write("---")
    col1, col2 = st.columns(2)
    original_pil_img = load_image(image_file)
    with col1:
        st.image(original_pil_img, caption="Original Image", use_container_width=True)

    with st.spinner("Analyzing the image and predicting segmentation..."):
        original_np_img = np.array(original_pil_img)
        # 1. Pre-processing for the model
        img_gray = convert_one_channel(original_np_img.copy())
        img_resized = cv2.resize(img_gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
        img_normalized = np.float32(img_resized / 255.0)
        img_input_np = np.reshape(img_normalized, (1, 512, 512, 1))
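        # The (1, 512, 512, 1) layout is (batch, height, width, channels);
        # 512x512 is assumed because both the original pipeline and the fallback
        # model above expect that resolution.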
        # 2. Run the prediction using the wrapper model
        try:
            # DO NOT convert to a TensorFlow tensor - pass the numpy array directly
            prediction = model.predict(img_input_np)
            # Convert the result to a numpy array if it is a tensor
            if hasattr(prediction, 'numpy'):
                prediction = prediction.numpy()
        except Exception as e:
            st.error(f"Prediction failed. Error: {e}")
            st.code(traceback.format_exc())
            st.stop()
        # 3. Post-processing of the prediction mask
        # Handle the case where the prediction might have different dimensions
        if len(prediction.shape) == 4:
            predicted_mask = prediction[0]  # Drop the batch dimension
        else:
            predicted_mask = prediction
        # If the mask has more than 2 dimensions, take the first channel
        if len(predicted_mask.shape) > 2:
            predicted_mask = predicted_mask[:, :, 0]
        # Resize the mask to the original image dimensions
        resized_mask = cv2.resize(predicted_mask,
                                  (original_np_img.shape[1], original_np_img.shape[0]),
                                  interpolation=cv2.INTER_LANCZOS4)
        # Binarize the mask with Otsu's threshold for a clean result
        mask_8bit = (resized_mask * 255).astype(np.uint8)
        _, final_mask = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Clean the mask with morphological operations to remove noise
        kernel = np.ones((5, 5), dtype=np.uint8)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_OPEN, kernel, iterations=2)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        # Find contours on the final mask
        contours, _ = cv2.findContours(final_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # Draw the contours on a color version of the original image
        img_for_drawing = convert_rgb(original_np_img.copy())
        output_image = cv2.drawContours(img_for_drawing, contours, -1, (255, 0, 0), 3)  # Red contours
    with col2:
        st.image(output_image, caption="Image with Segmented Teeth", use_container_width=True)
    st.success("Prediction complete!")