"""Streamlit demo: sinkhole segmentation with an EffV2-UNet Keras model.

Loads a saved Keras model once (cached by Streamlit), shows the example
images found in ``examples/``, and runs the model on whichever example the
user clicks, displaying the original image and the predicted mask.
"""

import os

# === Fix font/matplotlib warnings ===
# Set these before importing third-party libraries: matplotlib resolves its
# config/cache directory when it is first imported, so assigning the
# variables after some import has already pulled matplotlib in has no effect.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
os.environ["XDG_CACHE_HOME"] = "/tmp"

import numpy as np  # noqa: E402
import streamlit as st  # noqa: E402
import tensorflow as tf  # noqa: E402
from PIL import Image  # noqa: E402


# === Custom loss and metrics ===
# Passed to load_model() via custom_objects so the saved model can be
# deserialized; the dictionary keys must match the names used at training.


def weighted_dice_loss(y_true, y_pred):
    """Return ``1 - Dice coefficient`` between ``y_true`` and ``y_pred``.

    Both tensors are flattened, so the score is computed over the whole
    batch at once. ``smooth`` keeps the ratio defined when both are empty
    (the loss is then 0). Despite the name, no weighting is applied here.
    """
    smooth = 1e-6
    # Cast so integer ground-truth masks can be multiplied with float preds.
    y_true_f = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    denominator = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
    return 1 - ((2.0 * intersection + smooth) / denominator)


def iou_metric(y_true, y_pred):
    """Return the intersection-over-union of the masks thresholded at 0.5.

    The 1e-6 term avoids division by zero; when both masks are empty the
    result is therefore 0 rather than 1.
    """
    # Cast to float before comparing so integer inputs don't fail on 0.5.
    y_true = tf.cast(tf.cast(y_true, tf.float32) > 0.5, tf.float32)
    y_pred = tf.cast(tf.cast(y_pred, tf.float32) > 0.5, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    return intersection / (union + 1e-6)


def bce_loss(y_true, y_pred):
    """Return Keras binary cross-entropy between ``y_true`` and ``y_pred``."""
    return tf.keras.losses.binary_crossentropy(y_true, y_pred)


# === Load Model ===
model_path = "final_model_after_third_iteration_WDL0.07_0.5155/"


@st.cache_resource
def load_model():
    """Load the Keras model from ``model_path``.

    ``st.cache_resource`` keeps the loaded model across script reruns, so
    the (slow) load happens once rather than on every interaction.
    """
    return tf.keras.models.load_model(
        model_path,
        custom_objects={
            "weighted_dice_loss": weighted_dice_loss,
            "iou_metric": iou_metric,
            "bce_loss": bce_loss,
        },
    )


model = load_model()


# === Inference Function ===
def run_prediction(image):
    """Run the model on a PIL image and return the prediction as an image.

    The input is converted to RGB (dropping any alpha channel) and resized
    to 512x512, so the returned grayscale mask is always 512x512. The first
    output channel is min-max stretched to 0..255 for display.
    """
    image = image.convert("RGB").resize((512, 512))
    batch = np.expand_dims(np.array(image), axis=0)  # (1, 512, 512, 3) uint8
    # NOTE(review): raw 0-255 pixel values are fed to the model; this assumes
    # the model does its own input scaling -- confirm against training code.
    pred = model.predict(batch, verbose=0)[0, :, :, 0]
    # NOTE(review): per-image min-max stretching shows even a uniformly low
    # prediction at full contrast; this is a visualization, not a binary mask.
    pred_norm = (pred - pred.min()) / (pred.max() - pred.min() + 1e-6)
    mask = (pred_norm * 255).astype(np.uint8)
    return Image.fromarray(mask)


def _load_example(path):
    """Open an image file and read its pixels immediately.

    ``Image.open`` is lazy and keeps the file open; ``load()`` reads the data
    now, after which Pillow closes the file for single-frame images.
    """
    img = Image.open(path)
    img.load()
    return img


# === Streamlit UI ===
st.title("🕳️ Sinkhole Segmentation with EffV2-UNet")

example_dir = "examples"
example_exts = (".jpg", ".jpeg", ".png")

if os.path.isdir(example_dir):
    example_files = sorted(
        f for f in os.listdir(example_dir) if f.lower().endswith(example_exts)
    )
else:
    example_files = []

if not example_files:
    # st.columns() rejects a column count of 0, so stop with a message instead.
    st.warning(f"No example images found in '{example_dir}'.")
    st.stop()

# Display examples in columns; remember which one (if any) was clicked.
selected_img = None
cols = st.columns(len(example_files))
for col, filename in zip(cols, example_files):
    with col:
        example_img = _load_example(os.path.join(example_dir, filename))
        st.image(example_img, caption=filename, use_column_width=True)
        if st.button(f"Run on {filename}"):
            selected_img = example_img

# Render results outside the column loop so they use the full page width
# instead of being squeezed into the clicked example's narrow column.
if selected_img is not None:
    st.subheader("Original Image")
    st.image(selected_img, use_column_width=True)
    st.subheader("Predicted Mask")
    result = run_prediction(selected_img)
    st.image(result, use_column_width=True)