import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import os
# === Redirect matplotlib/font caches to a writable /tmp dir (Hugging Face Spaces) ===
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
os.environ["XDG_CACHE_HOME"] = "/tmp"
# === Custom loss and metrics ===
def weighted_dice_loss(y_true, y_pred):
    smooth = 1e-6
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return 1 - ((2. * intersection + smooth) /
                (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth))
def iou_metric(y_true, y_pred):
    y_true = tf.cast(y_true > 0.5, tf.float32)
    y_pred = tf.cast(y_pred > 0.5, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    return intersection / (union + 1e-6)
def bce_loss(y_true, y_pred):
    return tf.keras.losses.binary_crossentropy(y_true, y_pred)
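# NOTE: these function names must match the identifiers stored with the trained
# model so the custom_objects mapping below can resolve them at load time.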
# === Load model ===
model_path = "final_model_after_third_iteration_WDL0.07_0.5155/"
@st.cache_resource
def load_model():
    return tf.keras.models.load_model(
        model_path,
        custom_objects={
            "weighted_dice_loss": weighted_dice_loss,
            "iou_metric": iou_metric,
            "bce_loss": bce_loss
        }
    )
model = load_model()
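# st.cache_resource caches the loaded model for the lifetime of the server
# process, so deserialization happens once rather than on every Streamlit rerun.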
# === Title ===
st.title("🕳️ SinkSAM-Net - Self-Supervised Sinkhole Segmentation")
# === Session state for selected example ===
if "selected_example" not in st.session_state:
    st.session_state.selected_example = None
# === File uploader ===
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "tif", "tiff"])
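# NOTE (assumption): Pillow handles standard 8-bit TIFFs; multi-band or 16-bit
# GeoTIFF rasters may not convert cleanly to RGB and would need a reader such
# as rasterio instead.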
# === Example selector ===
example_dir = "examples"
example_files = sorted([
    f for f in os.listdir(example_dir)
    if f.lower().endswith((".jpg", ".jpeg", ".png", ".tif", ".tiff"))
])
if example_files:
    st.subheader("🖼️ Try with an Example Image")
    cols = st.columns(min(len(example_files), 4))
    for i, file in enumerate(example_files):
        with cols[i % len(cols)]:
            img_path = os.path.join(example_dir, file)
            example_img = Image.open(img_path)
            st.image(example_img, caption=file, use_container_width=True)
            if st.button("Run Segmentation", key=file):
                st.session_state.selected_example = img_path
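# A button click triggers a Streamlit rerun, so the chosen example path is kept
# in st.session_state to survive that rerun.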
# === Determine active image ===
active_image = None
if uploaded_image is not None:
    active_image = uploaded_image
elif st.session_state.selected_example is not None:
    active_image = st.session_state.selected_example
# === Confidence threshold slider ===
threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5, step=0.01)
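# The slider value binarizes the model output below (assumed to be a per-pixel
# probability); the default of 0.5 matches the fixed threshold used inside iou_metric.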
# === Prediction ===
if active_image:
    image = Image.open(active_image).convert("RGB")
    st.image(image, caption="Input Image", use_container_width=True)
    resized = image.resize((512, 512))
    x = np.expand_dims(np.array(resized), axis=0)
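    # NOTE (assumption): the array is fed as raw uint8 values in [0, 255]. If the
    # saved model was trained on inputs rescaled to [0, 1] and has no built-in
    # Rescaling layer, uncomment the line below before predicting.
    # x = x.astype(np.float32) / 255.0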
    y = model.predict(x)[0, :, :, 0]
    st.text(f"Prediction min/max: {y.min():.5f} / {y.max():.5f}")
    # Apply threshold
    mask_bin = (y > threshold).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask_bin)
    st.image(mask_image, caption=f"Segmentation (Threshold = {threshold:.2f})", use_container_width=True)
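    # Optional visual-QA sketch (not part of the original app): blend the binary
    # mask over the 512x512 input so predictions can be inspected in context.
    # Uncomment to display the overlay alongside the raw mask.
    # overlay = Image.blend(
    #     resized.convert("RGB"),
    #     Image.merge("RGB", (mask_image, mask_image, mask_image)),
    #     alpha=0.4,
    # )
    # st.image(overlay, caption="Mask overlay (alpha = 0.4)", use_container_width=True)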