File size: 3,329 Bytes
73b3af8
 
 
 
 
 
e9f8684
73b3af8
 
 
 
 
 
 
 
 
e9f8684
 
73b3af8
 
 
 
 
 
 
 
 
 
 
e9f8684
73b3af8
 
 
 
 
 
 
 
 
 
 
 
 
 
e9f8684
73b3af8
 
862a0c0
 
 
e9f8684
862a0c0
 
 
 
cb364c2
e9f8684
 
 
 
73b3af8
862a0c0
 
e9f8684
 
862a0c0
e9f8684
 
862a0c0
e9f8684
862a0c0
 
 
e9f8684
862a0c0
 
 
 
 
e9f8684
 
 
 
 
1b9d572
e9f8684
cb364c2
862a0c0
 
 
 
e9f8684
862a0c0
e9f8684
862a0c0
 
73b3af8
862a0c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import os

# === Fix font/matplotlib warnings for Hugging Face ===
# Point matplotlib's config directory and the XDG cache at /tmp
# (presumably because the default locations are not writable on the host
# -- TODO confirm).
os.environ.update({
    "MPLCONFIGDIR": "/tmp/matplotlib",
    "XDG_CACHE_HOME": "/tmp",
})

# === Custom loss and metrics ===
def weighted_dice_loss(y_true, y_pred):
    """Return the Dice loss (1 - Dice coefficient) of the two tensors.

    Both inputs are flattened to 1-D first, so the score is computed over
    every element at once rather than per sample. Despite the name, no
    weighting is applied in this function.
    """
    eps = 1e-6  # keeps the ratio defined when both inputs sum to zero
    truth_flat = tf.reshape(y_true, [-1])
    pred_flat = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(truth_flat * pred_flat)
    total = tf.reduce_sum(truth_flat) + tf.reduce_sum(pred_flat)
    dice = (2. * overlap + eps) / (total + eps)
    return 1 - dice

def iou_metric(y_true, y_pred):
    """Return intersection-over-union after binarizing both inputs at 0.5.

    The sums run over all elements of the tensors, and a small constant in
    the denominator avoids division by zero when both masks are empty.
    """
    truth_mask = tf.cast(y_true > 0.5, tf.float32)
    pred_mask = tf.cast(y_pred > 0.5, tf.float32)
    overlap = tf.reduce_sum(truth_mask * pred_mask)
    truth_area = tf.reduce_sum(truth_mask)
    pred_area = tf.reduce_sum(pred_mask)
    union = truth_area + pred_area - overlap
    return overlap / (union + 1e-6)

def bce_loss(y_true, y_pred):
    """Return Keras' binary cross-entropy between targets and predictions."""
    binary_crossentropy = tf.keras.losses.binary_crossentropy
    return binary_crossentropy(y_true, y_pred)

# === Load model ===
model_path = "final_model_after_third_iteration_WDL0.07_0.5155/"


@st.cache_resource
def load_model():
    """Load the saved Keras model from ``model_path``.

    The custom loss and metric functions defined in this file are passed as
    ``custom_objects`` so Keras can resolve them by name while loading. The
    ``st.cache_resource`` decorator caches the returned model.
    """
    custom_objects = {
        "weighted_dice_loss": weighted_dice_loss,
        "iou_metric": iou_metric,
        "bce_loss": bce_loss,
    }
    return tf.keras.models.load_model(model_path, custom_objects=custom_objects)

model = load_model()

# === Title ===
st.title("🕳️ Sinkhole Segmentation with EffV2-UNet")

# === Sidebar: cut-off used to binarize the predicted mask ===
st.sidebar.header("Segmentation Settings")
threshold = st.sidebar.slider(
    "Confidence Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
    step=0.01,
)

# === Image input section ===
uploaded_image = st.file_uploader(
    "📤 Upload an image",
    type=["png", "jpg", "jpeg", "tif", "tiff"],
)

# === Example selector with preview ===
example_dir = "examples"
image_suffixes = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
# A missing examples folder simply disables the gallery instead of crashing
# the whole app with FileNotFoundError.
example_files = sorted(
    f for f in (os.listdir(example_dir) if os.path.isdir(example_dir) else [])
    if f.lower().endswith(image_suffixes)
)

# Streamlit reruns the whole script on every interaction, and a button only
# returns True during the rerun triggered by its own click. The chosen example
# is therefore kept in session state; otherwise clicking "Run Segmentation"
# would start a rerun in which no example is selected any more.
upload_token = (
    (uploaded_image.name, uploaded_image.size)
    if uploaded_image is not None
    else None
)
if st.session_state.get("last_upload_token") != upload_token:
    st.session_state["last_upload_token"] = upload_token
    if upload_token is not None:
        # A newly uploaded file replaces any previously chosen example.
        st.session_state.pop("selected_example_path", None)

if example_files:
    st.subheader("🖼️ Try with an Example Image")
    cols = st.columns(min(len(example_files), 4))
    for i, file in enumerate(example_files):
        img_path = os.path.join(example_dir, file)
        # Release the file handle as soon as the thumbnail has been built.
        with Image.open(img_path) as example_img:
            img_preview = example_img.convert("RGB").resize((128, 128))
        # Wrap around so more than four examples fill the columns row by row.
        with cols[i % len(cols)]:
            st.image(img_preview, caption=file, use_column_width=True)
            if st.button(f"Use {file}", key=file):
                st.session_state["selected_example_path"] = img_path

selected_example_path = st.session_state.get("selected_example_path")

# === Set image to process ===
# A selected example takes precedence over the uploader's current file.
if selected_example_path:
    uploaded_image = selected_example_path

# === Run prediction if button clicked ===
if uploaded_image:
    # Image.open accepts both a filesystem path and the uploaded file object.
    image = Image.open(uploaded_image).convert("RGB")

    st.image(image, caption="Input Image", use_column_width=True)

    if st.button("Run Segmentation"):
        resized = image.resize((512, 512))
        # NOTE(review): raw 0-255 pixel values are passed to the model;
        # presumably any rescaling happens inside the model -- confirm.
        x = np.expand_dims(np.array(resized), axis=0)
        # Take the first (only) batch item and the first output channel.
        y = model.predict(x)[0, :, :, 0]

        st.text(f"Prediction min/max: {y.min():.5f} / {y.max():.5f}")

        # Binarize at the sidebar threshold and scale to 0/255 for display.
        mask_bin = (y > threshold).astype(np.uint8) * 255
        mask_image = Image.fromarray(mask_bin)

        st.image(mask_image, caption=f"Segmentation Mask (Threshold = {threshold:.2f})", use_column_width=True)