Spaces:
Sleeping
Sleeping
File size: 3,329 Bytes
73b3af8 e9f8684 73b3af8 e9f8684 73b3af8 e9f8684 73b3af8 e9f8684 73b3af8 862a0c0 e9f8684 862a0c0 cb364c2 e9f8684 73b3af8 862a0c0 e9f8684 862a0c0 e9f8684 862a0c0 e9f8684 862a0c0 e9f8684 862a0c0 e9f8684 1b9d572 e9f8684 cb364c2 862a0c0 e9f8684 862a0c0 e9f8684 862a0c0 73b3af8 862a0c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import os
# === Point matplotlib / font caches at a writable location ===
# The original intent was to silence font/matplotlib warnings on Hugging Face;
# presumably the default cache directories are not writable there, hence /tmp.
# NOTE(review): these take effect only for libraries that read them after this
# point — anything imported above that already initialised its cache is not affected.
os.environ.update(
    {
        "MPLCONFIGDIR": "/tmp/matplotlib",
        "XDG_CACHE_HOME": "/tmp",
    }
)
# === Custom loss and metrics ===
def weighted_dice_loss(y_true, y_pred):
    """Return the soft Dice loss ``1 - Dice(y_true, y_pred)``.

    Both tensors are flattened, so a single Dice coefficient is computed over
    every element of the batch rather than per sample. Despite the name, no
    weighting is applied in this implementation.

    Args:
        y_true: Ground-truth mask. Any numeric or boolean dtype is accepted;
            it is cast to ``y_pred``'s dtype before use.
        y_pred: Predicted mask with the same number of elements as ``y_true``.

    Returns:
        Scalar tensor; lies in ``[0, 1]`` when both inputs are in ``[0, 1]``.
    """
    smooth = 1e-6  # keeps the ratio defined (1.0) when both masks are empty
    y_pred = tf.convert_to_tensor(y_pred)
    # Cast explicitly: integer/boolean masks would otherwise raise a dtype
    # mismatch in the element-wise multiplication below.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    denominator = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
def iou_metric(y_true, y_pred):
    """Hard Intersection-over-Union after binarising both inputs at 0.5.

    Elements strictly greater than 0.5 count as foreground. Sums run over
    every element, so the whole batch is pooled into one score. The small
    epsilon in the denominator makes two empty masks score 0 instead of
    dividing by zero.
    """
    eps = 1e-6
    truth_mask = tf.cast(y_true > 0.5, tf.float32)
    pred_mask = tf.cast(y_pred > 0.5, tf.float32)
    overlap = tf.reduce_sum(truth_mask * pred_mask)
    combined = tf.reduce_sum(truth_mask) + tf.reduce_sum(pred_mask) - overlap
    return overlap / (combined + eps)
def bce_loss(y_true, y_pred):
    """Delegate to Keras' functional binary cross-entropy.

    Exists so a saved model that refers to a loss named ``bce_loss`` can be
    deserialised; it adds no behaviour of its own.
    """
    keras_losses = tf.keras.losses
    return keras_losses.binary_crossentropy(y_true, y_pred)
# === Load model ===
# NOTE(review): relative path, resolved against the process's working
# directory — the app presumably has to be launched from the repo root.
model_path = "final_model_after_third_iteration_WDL0.07_0.5155/"


@st.cache_resource
def load_model():
    """Deserialise the segmentation model stored at ``model_path``.

    Wrapped in ``st.cache_resource`` so the model object is reused across
    script reruns instead of being reloaded on every interaction. The
    custom loss/metric functions are registered by name so Keras can resolve
    the references stored in the saved model.
    """
    registered = {
        "weighted_dice_loss": weighted_dice_loss,
        "iou_metric": iou_metric,
        "bce_loss": bce_loss,
    }
    return tf.keras.models.load_model(model_path, custom_objects=registered)


model = load_model()
# === Page header ===
st.title("🕳️ Sinkhole Segmentation with EffV2-UNet")

# === Sidebar: threshold applied to the model's per-pixel output ===
with st.sidebar:
    st.header("Segmentation Settings")
    threshold = st.slider(
        "Confidence Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.01,
    )

# === Image input: user upload (an example image may replace it further down) ===
uploaded_image = st.file_uploader(
    "📤 Upload an image",
    type=["png", "jpg", "jpeg", "tif", "tiff"],
)
# === Example gallery: pick a bundled image instead of uploading one ===
example_dir = "examples"
_EXAMPLE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")

# Streamlit re-executes this script on every interaction, and ``st.button``
# is True only during the rerun caused by its own click. The chosen example
# therefore has to live in ``st.session_state``. A plain local variable is
# reset as soon as the user clicks "Run Segmentation", which made example
# images impossible to segment.
_EXAMPLE_STATE_KEY = "selected_example_path"
_UPLOAD_STATE_KEY = "last_upload_signature"


def _list_example_files(directory):
    """Return sorted image file names in *directory*.

    Only names ending in one of ``_EXAMPLE_EXTENSIONS`` (case-insensitive)
    are kept. A missing directory yields ``[]`` so the app still starts
    without bundled examples.
    """
    try:
        names = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        return []
    return sorted(name for name in names if name.lower().endswith(_EXAMPLE_EXTENSIONS))


def _upload_signature(uploaded):
    """Return a hashable tag for the current upload, or None when there is none.

    Uses ``file_id`` when the upload object has one; otherwise falls back to
    name and size, which is an approximation.
    """
    if uploaded is None:
        return None
    file_id = getattr(uploaded, "file_id", None)
    if file_id is not None:
        return file_id
    return (getattr(uploaded, "name", None), getattr(uploaded, "size", None))


example_files = _list_example_files(example_dir)

# A newly uploaded file replaces a previously chosen example, so the user's
# most recent action decides which image is processed.
current_upload = _upload_signature(uploaded_image)
if st.session_state.get(_UPLOAD_STATE_KEY) != current_upload:
    st.session_state[_UPLOAD_STATE_KEY] = current_upload
    if current_upload is not None:
        st.session_state.pop(_EXAMPLE_STATE_KEY, None)

if example_files:
    st.subheader("🖼️ Try with an Example Image")
    cols = st.columns(min(len(example_files), 4))
    for i, file in enumerate(example_files):
        img_path = os.path.join(example_dir, file)
        # Context manager so the preview's file handle is released promptly.
        with Image.open(img_path) as src:
            img_preview = src.convert("RGB").resize((128, 128))
        with cols[i % len(cols)]:
            # NOTE(review): newer Streamlit releases deprecate
            # use_column_width; switch once the pinned version allows it.
            st.image(img_preview, caption=file, use_column_width=True)
            if st.button(f"Use {file}", key=file):
                st.session_state[_EXAMPLE_STATE_KEY] = img_path

# === Set image to process ===
# The persisted example (if any) overrides the upload, and survives the
# rerun triggered by "Run Segmentation".
selected_example_path = st.session_state.get(_EXAMPLE_STATE_KEY)
if selected_example_path:
    uploaded_image = selected_example_path
# === Run segmentation on the chosen image ===
# Raw model probabilities are kept in session state. Moving the sidebar
# threshold reruns the script and resets st.button, so without this the mask
# would disappear instead of being re-thresholded.
_PREDICTION_STATE_KEY = "last_prediction"


def _source_identity(source):
    """Return a hashable tag for *source* that is stable across reruns.

    *source* is either a path string (an example image) or a Streamlit upload
    object. For uploads, ``file_id`` is used when available; otherwise name
    and size are used, which could confuse two different files with
    identical name and size.
    """
    if isinstance(source, str):
        return ("path", source)
    file_id = getattr(source, "file_id", None)
    if file_id is not None:
        return ("upload", file_id)
    return ("upload", getattr(source, "name", None), getattr(source, "size", None))


if uploaded_image:
    # Image.open accepts both a path and a file-like object, so examples and
    # uploads share one code path.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except (OSError, ValueError) as err:
        # OSError includes PIL.UnidentifiedImageError and truncated files;
        # ValueError can come from mode conversions PIL does not support.
        st.error(f"Could not read the image: {err}")
        st.stop()
    st.image(image, caption="Input Image", use_column_width=True)

    source_id = _source_identity(uploaded_image)
    if st.button("Run Segmentation"):
        resized = image.resize((512, 512))
        # NOTE(review): pixels are fed as raw 0-255 values; presumably the
        # model normalises internally — confirm against the training code.
        x = np.expand_dims(np.array(resized), axis=0)
        probabilities = model.predict(x)[0, :, :, 0]
        st.session_state[_PREDICTION_STATE_KEY] = (source_id, probabilities)

    # Render whenever a prediction exists for *this* image, so the threshold
    # slider updates the mask live without calling the model again.
    cached = st.session_state.get(_PREDICTION_STATE_KEY)
    if cached is not None and cached[0] == source_id:
        y = cached[1]
        st.text(f"Prediction min/max: {y.min():.5f} / {y.max():.5f}")
        mask_bin = (y > threshold).astype(np.uint8) * 255
        # Scale the 512x512 mask back to the input's size so it lines up with
        # the "Input Image" preview; NEAREST keeps it strictly binary.
        mask_image = Image.fromarray(mask_bin).resize(image.size, Image.NEAREST)
        st.image(mask_image, caption=f"Segmentation Mask (Threshold = {threshold:.2f})", use_column_width=True)
|