File size: 3,301 Bytes
73b3af8
 
 
 
 
 
8277b50
73b3af8
 
 
8277b50
73b3af8
 
 
 
 
8277b50
e9f8684
73b3af8
 
 
 
 
 
 
 
 
 
 
8277b50
 
73b3af8
 
 
8277b50
73b3af8
 
 
 
 
 
 
 
 
8277b50
6c2ed67
73b3af8
8277b50
 
 
e9f8684
8277b50
 
 
 
cb364c2
e9f8684
 
 
 
73b3af8
e9f8684
8277b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fe897d
8277b50
 
 
 
7fe897d
8277b50
 
 
 
7fe897d
8277b50
 
 
7fe897d
8277b50
e9f8684
8277b50
 
 
e9f8684
8277b50
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import os

# === Fix font/matplotlib warnings for Hugging Face ===
# Redirect matplotlib's config dir and the XDG cache dir to /tmp.
# NOTE(review): presumably the default (home-based) locations are not
# writable on the hosting environment — confirm. matplotlib reads
# MPLCONFIGDIR when it is imported, so this only helps if nothing has
# imported matplotlib yet.
os.environ.update(
    {
        "MPLCONFIGDIR": "/tmp/matplotlib",
        "XDG_CACHE_HOME": "/tmp",
    }
)

# === Custom loss and metrics ===
def weighted_dice_loss(y_true, y_pred):
    """Return the soft Dice loss, ``1 - Dice(y_true, y_pred)``.

    Both tensors are flattened before summing, so the overlap is pooled over
    every element of the batch rather than computed per sample. Despite the
    name, no weighting is applied; the name is kept because it is the key
    this function is registered under when the saved model is loaded.

    Args:
        y_true: Ground-truth mask. Any numeric or bool dtype is accepted; it
            is cast to ``y_pred``'s dtype before use.
        y_pred: Predicted values with the same number of elements as
            ``y_true``.

    Returns:
        A scalar loss tensor.
    """
    smooth = 1e-6  # keeps the ratio defined (== 1) when both masks are empty
    y_pred = tf.convert_to_tensor(y_pred)
    # TensorFlow refuses to multiply tensors of different dtypes (e.g. integer
    # labels with float predictions), so align the labels with the predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    denominator = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
    return 1 - (2.0 * intersection + smooth) / denominator

def iou_metric(y_true, y_pred):
    """Return the intersection-over-union of the two binarized masks.

    Each input is thresholded with a strict ``> 0.5`` and converted to
    float32, then overlap and union are summed over all elements at once.
    A small epsilon in the denominator keeps the result finite when both
    masks are empty (the IoU is then 0).
    """
    eps = 1e-6
    true_mask, pred_mask = (
        tf.cast(t > 0.5, tf.float32) for t in (y_true, y_pred)
    )
    overlap = tf.reduce_sum(true_mask * pred_mask)
    total = tf.reduce_sum(true_mask) + tf.reduce_sum(pred_mask)
    return overlap / (total - overlap + eps)

def bce_loss(y_true, y_pred):
    """Binary cross-entropy, delegated to Keras' built-in implementation.

    Thin named wrapper so the loss can be registered as a custom object
    when the saved model is loaded.
    """
    keras_bce = tf.keras.losses.binary_crossentropy
    return keras_bce(y_true, y_pred)

# === Load model ===
# NOTE(review): this relative path is resolved against the process's current
# working directory, not this script's location.
model_path = "final_model_after_third_iteration_WDL0.07_0.5155/"


@st.cache_resource
def load_model():
    """Load and return the trained Keras segmentation model.

    ``st.cache_resource`` caches the returned model object so Streamlit
    reruns reuse it instead of loading it from disk again. The custom
    loss/metric functions are passed as ``custom_objects`` so Keras can
    resolve them by name while deserializing the saved model.
    """
    custom_objects = dict(
        weighted_dice_loss=weighted_dice_loss,
        iou_metric=iou_metric,
        bce_loss=bce_loss,
    )
    return tf.keras.models.load_model(model_path, custom_objects=custom_objects)


model = load_model()

# === Title ===
st.title("🕳️ SinkSAM-Net - Self Supervised Sinkhole segmentation")

# === Session state for selected example ===
# Holds the path of the example image the user last picked so the choice
# survives Streamlit reruns; None means no example has been chosen yet.
if "selected_example" not in st.session_state:
    st.session_state["selected_example"] = None

# === File uploader ===
accepted_types = ["png", "jpg", "jpeg", "tif", "tiff"]
uploaded_image = st.file_uploader("Upload an image", type=accepted_types)

# === Example selector ===
# Show a grid (at most 4 columns) of bundled example images, each with a
# button that records its path in session state for segmentation below.
example_dir = "examples"
# A missing examples folder should simply hide this section instead of
# crashing the whole app with FileNotFoundError from os.listdir.
if os.path.isdir(example_dir):
    example_files = sorted(
        f
        for f in os.listdir(example_dir)
        if f.lower().endswith((".jpg", ".jpeg", ".png", ".tif", ".tiff"))
    )
else:
    example_files = []

if example_files:
    st.subheader("🖼️ Try with an Example Image")
    cols = st.columns(min(len(example_files), 4))

    for i, file in enumerate(example_files):
        with cols[i % len(cols)]:
            img_path = os.path.join(example_dir, file)
            # The context manager releases the file handle once the thumbnail
            # has been handed to st.image.
            with Image.open(img_path) as example_img:
                st.image(example_img, caption=file, use_container_width=True)
            if st.button("Run Segmentation", key=file):
                st.session_state.selected_example = img_path

# === Determine active image ===
# A freshly uploaded file takes priority over a previously chosen example;
# when neither exists this ends up None and nothing is segmented.
active_image = (
    uploaded_image
    if uploaded_image is not None
    else st.session_state.selected_example
)

# === Confidence threshold slider ===
threshold = st.slider(
    "Confidence Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
    step=0.01,
)

# === Prediction ===
# Run the model on the active image (an uploaded file or an example path),
# then show the raw output range and the thresholded binary mask.
if active_image is not None:
    # convert() returns an independent in-memory copy, so the source file can
    # be closed right away by the context manager.
    with Image.open(active_image) as source_image:
        image = source_image.convert("RGB")
    st.image(image, caption="Input Image", use_container_width=True)

    # NOTE(review): pixels are fed as raw 0-255 values with no normalization;
    # confirm this matches the preprocessing used during training.
    resized = image.resize((512, 512))
    x = np.expand_dims(np.array(resized), axis=0)
    y = model.predict(x)[0, :, :, 0]

    st.text(f"Prediction min/max: {y.min():.5f} / {y.max():.5f}")

    # Apply threshold
    mask_bin = (y > threshold).astype(np.uint8) * 255
    # The model runs at 512x512; scale the mask back to the input's size so it
    # lines up with the input image shown above. Nearest-neighbour resampling
    # keeps the mask strictly binary (0/255).
    mask_image = Image.fromarray(mask_bin).resize(image.size, Image.NEAREST)

    st.image(
        mask_image,
        caption=f"Segmentation (Threshold = {threshold:.2f})",
        use_container_width=True,
    )