File size: 2,443 Bytes
73b3af8
 
 
 
 
 
cb364c2
73b3af8
 
 
 
 
 
 
 
 
cb364c2
73b3af8
 
 
 
 
 
 
 
 
 
 
cb364c2
73b3af8
 
 
 
 
 
 
 
 
 
 
 
 
 
cb364c2
 
 
 
 
 
 
 
 
73b3af8
 
 
cb364c2
 
73b3af8
cb364c2
 
73b3af8
cb364c2
 
 
 
 
1b9d572
cb364c2
 
 
 
 
 
 
73b3af8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import os

# === Fix font/matplotlib warnings ===
# Redirect matplotlib's config directory and the XDG cache root to /tmp.
# NOTE(review): presumably the default locations are not writable where this
# app is deployed -- confirm. These assignments run after the imports above,
# so they can only influence code that reads the variables later on.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
os.environ["XDG_CACHE_HOME"] = "/tmp"

# === Custom loss and metrics ===
def weighted_dice_loss(y_true, y_pred):
    """Return the soft Dice loss (``1 - Dice coefficient``) as a scalar.

    Both tensors are flattened to 1-D first, so the coefficient is computed
    over every element at once rather than per sample. Despite the name, no
    weighting is applied in this function.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor with the same number of elements.

    Returns:
        A scalar tensor equal to one minus the smoothed Dice coefficient.
    """
    # Small constant keeps the ratio defined when both inputs sum to zero.
    eps = 1e-6
    flat_true = tf.reshape(y_true, [-1])
    flat_pred = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(flat_true * flat_pred)
    total = tf.reduce_sum(flat_true) + tf.reduce_sum(flat_pred)
    dice = (2. * overlap + eps) / (total + eps)
    return 1 - dice

def iou_metric(y_true, y_pred):
    """Return the intersection-over-union of the two inputs after thresholding.

    Each input is binarised at 0.5 (strictly greater than) and cast to
    float32 before the sums are taken over all elements.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor.

    Returns:
        A scalar tensor: ``intersection / (union + 1e-6)``.
    """
    true_mask = tf.cast(y_true > 0.5, tf.float32)
    pred_mask = tf.cast(y_pred > 0.5, tf.float32)
    overlap = tf.reduce_sum(true_mask * pred_mask)
    # Union by inclusion-exclusion: |A| + |B| - |A and B|.
    combined = tf.reduce_sum(true_mask) + tf.reduce_sum(pred_mask) - overlap
    # The epsilon avoids division by zero when both masks are empty.
    return overlap / (combined + 1e-6)

def bce_loss(y_true, y_pred):
    """Return Keras binary cross-entropy between ``y_true`` and ``y_pred``.

    A named wrapper around ``tf.keras.losses.binary_crossentropy`` so the
    loss can be referenced by name in ``custom_objects``.
    """
    loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return loss

# === Load Model ===
# Path handed to Keras when loading the saved model.
model_path = "final_model_after_third_iteration_WDL0.07_0.5155/"


@st.cache_resource
def load_model():
    """Load the saved Keras model from ``model_path``.

    The custom loss and metric functions defined in this file are passed as
    ``custom_objects`` under their own names. The function is wrapped in
    ``st.cache_resource``.

    Returns:
        Whatever ``tf.keras.models.load_model`` returns for ``model_path``.
    """
    custom_fns = {
        "weighted_dice_loss": weighted_dice_loss,
        "iou_metric": iou_metric,
        "bce_loss": bce_loss,
    }
    return tf.keras.models.load_model(model_path, custom_objects=custom_fns)


# Module-level model instance used by the inference function.
model = load_model()

# === Inference Function ===
def run_prediction(image):
    """Run the module-level ``model`` on one image and return a grayscale mask.

    The image is converted to RGB, resized to 512x512 and given a leading
    batch axis before prediction. The first channel of the first prediction
    is min-max scaled and stored as 8-bit values.

    Args:
        image: A PIL image (anything providing ``convert`` and ``resize``).

    Returns:
        A PIL image built from the 2-D ``uint8`` mask array.
    """
    rgb = image.convert("RGB")
    resized = rgb.resize((512, 512))
    # The model is fed a batch of one image.
    batch = np.array(resized)[np.newaxis, ...]
    prediction = model.predict(batch)
    # First sample, first output channel.
    channel = prediction[0, :, :, 0]
    lowest = channel.min()
    highest = channel.max()
    # The epsilon keeps the division defined for a constant prediction.
    scaled = (channel - lowest) / (highest - lowest + 1e-6)
    pixels = (scaled * 255).astype(np.uint8)
    return Image.fromarray(pixels)

# === Streamlit UI ===
# Shows every example image in its own column with a button that runs the
# model on that image and renders the original next to the predicted mask.
st.title("🕳️ Sinkhole Segmentation with EffV2-UNet")

example_dir = "examples"

# A missing or unreadable examples folder should degrade to "no examples"
# rather than crash the whole app with a traceback.
try:
    example_files = sorted(
        f for f in os.listdir(example_dir) if f.lower().endswith((".jpg", ".png"))
    )
except OSError:
    example_files = []

if example_files:
    # Display examples in columns, one column per image.
    cols = st.columns(len(example_files))
else:
    # st.columns() rejects a column count of zero, so skip the layout and
    # tell the user why nothing is shown.
    cols = []
    st.warning(f"No example images (.jpg / .png) found in '{example_dir}'.")

for col, filename in zip(cols, example_files):
    with col:
        img_path = os.path.join(example_dir, filename)
        example_img = Image.open(img_path)
        st.image(example_img, caption=filename, use_column_width=True)

        # The file name is part of the label, which keeps each button unique.
        if st.button(f"Run on {filename}"):
            st.subheader("Original Image")
            st.image(example_img, use_column_width=True)

            st.subheader("Predicted Mask")
            result = run_prediction(example_img)
            st.image(result, use_column_width=True)