osherr commited on
Commit
e9f8684
·
verified ·
1 Parent(s): cb364c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -26
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  from PIL import Image
5
  import os
6
 
7
- # === Fix font/matplotlib warnings ===
8
  os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
9
  os.environ["XDG_CACHE_HOME"] = "/tmp"
10
 
@@ -14,7 +14,8 @@ def weighted_dice_loss(y_true, y_pred):
14
  y_true_f = tf.reshape(y_true, [-1])
15
  y_pred_f = tf.reshape(y_pred, [-1])
16
  intersection = tf.reduce_sum(y_true_f * y_pred_f)
17
- return 1 - ((2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth))
 
18
 
19
  def iou_metric(y_true, y_pred):
20
  y_true = tf.cast(y_true > 0.5, tf.float32)
@@ -26,7 +27,7 @@ def iou_metric(y_true, y_pred):
26
  def bce_loss(y_true, y_pred):
27
  return tf.keras.losses.binary_crossentropy(y_true, y_pred)
28
 
29
- # === Load Model ===
30
  model_path = "final_model_after_third_iteration_WDL0.07_0.5155/"
31
  @st.cache_resource
32
  def load_model():
@@ -41,35 +42,51 @@ def load_model():
41
 
42
  model = load_model()
43
 
44
- # === Inference Function ===
45
- def run_prediction(image):
46
- image = image.convert("RGB").resize((512, 512))
47
- x = np.expand_dims(np.array(image), axis=0)
48
- y = model.predict(x)[0, :, :, 0]
49
- y_norm = (y - y.min()) / (y.max() - y.min() + 1e-6)
50
- mask = (y_norm * 255).astype(np.uint8)
51
- return Image.fromarray(mask)
52
-
53
- # === Streamlit UI ===
54
  st.title("🕳️ Sinkhole Segmentation with EffV2-UNet")
55
 
 
 
 
 
56
  example_dir = "examples"
57
- example_files = sorted([f for f in os.listdir(example_dir) if f.lower().endswith((".jpg", ".png"))])
 
 
 
58
 
59
- # Display examples in columns
60
- cols = st.columns(len(example_files))
 
61
 
62
- for i, filename in enumerate(example_files):
63
- with cols[i]:
64
- img_path = os.path.join(example_dir, filename)
65
  example_img = Image.open(img_path)
66
- st.image(example_img, caption=filename, use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- if st.button(f"Run on {filename}"):
69
- st.subheader("Original Image")
70
- st.image(example_img, use_column_width=True)
71
 
72
- st.subheader("Predicted Mask")
73
- result = run_prediction(example_img)
74
- st.image(result, use_column_width=True)
 
 
 
 
 
 
 
 
75
 
 
 
4
  from PIL import Image
5
  import os
6
 
7
+ # === Fix font/matplotlib warnings for Hugging Face ===
8
  os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
9
  os.environ["XDG_CACHE_HOME"] = "/tmp"
10
 
 
14
  y_true_f = tf.reshape(y_true, [-1])
15
  y_pred_f = tf.reshape(y_pred, [-1])
16
  intersection = tf.reduce_sum(y_true_f * y_pred_f)
17
+ return 1 - ((2. * intersection + smooth) /
18
+ (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth))
19
 
20
  def iou_metric(y_true, y_pred):
21
  y_true = tf.cast(y_true > 0.5, tf.float32)
 
27
  def bce_loss(y_true, y_pred):
28
  return tf.keras.losses.binary_crossentropy(y_true, y_pred)
29
 
30
+ # === Load model ===
31
  model_path = "final_model_after_third_iteration_WDL0.07_0.5155/"
32
  @st.cache_resource
33
  def load_model():
 
42
 
43
  model = load_model()
44
 
45
+ # === Title ===
 
 
 
 
 
 
 
 
 
46
  st.title("🕳️ Sinkhole Segmentation with EffV2-UNet")
47
 
48
+ # === File uploader ===
49
+ uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "tif", "tiff"])
50
+
51
+ # === Example selector ===
52
  example_dir = "examples"
53
+ example_files = sorted([
54
+ f for f in os.listdir(example_dir)
55
+ if f.lower().endswith((".jpg", ".jpeg", ".png", ".tif", ".tiff"))
56
+ ])
57
 
58
+ if example_files:
59
+ st.subheader("🖼️ Try with an Example Image")
60
+ cols = st.columns(min(len(example_files), 4)) # up to 4 per row
61
 
62
+ for i, file in enumerate(example_files):
63
+ img_path = os.path.join(example_dir, file)
 
64
  example_img = Image.open(img_path)
65
+ with cols[i % len(cols)]:
66
+ if st.button(file, key=file):
67
+ uploaded_image = img_path # simulate upload
68
+ image = example_img.convert("RGB")
69
+ st.image(image, caption=f"Example: {file}", use_column_width=True)
70
+
71
+ # === Prediction ===
72
+ if uploaded_image:
73
+ if isinstance(uploaded_image, str):
74
+ image = Image.open(uploaded_image).convert("RGB")
75
+ else:
76
+ image = Image.open(uploaded_image).convert("RGB")
77
 
78
+ st.image(image, caption="Input Image", use_column_width=True)
 
 
79
 
80
+ # Preprocess and predict
81
+ resized = image.resize((512, 512))
82
+ x = np.expand_dims(np.array(resized), axis=0)
83
+ y = model.predict(x)[0, :, :, 0]
84
+
85
+ st.text(f"Prediction min/max: {y.min():.5f} / {y.max():.5f}")
86
+
87
+ # Normalize output
88
+ y_norm = (y - y.min()) / (y.max() - y.min() + 1e-6)
89
+ mask = (y_norm * 255).astype(np.uint8)
90
+ result = Image.fromarray(mask)
91
 
92
+ st.image(result, caption="Predicted Segmentation", use_column_width=True)