MateuszLis committed on
Commit
356875c
·
verified ·
1 Parent(s): cc46213

Update saliency_gradio.py

Browse files
Files changed (1) hide show
  1. saliency_gradio.py +27 -14
saliency_gradio.py CHANGED
@@ -51,26 +51,36 @@ def postprocess_output(output_tensor, vertical_padding, horizontal_padding, orig
51
  output_tensor = tf.image.resize(output_tensor, original_shape)
52
  return output_tensor.numpy().squeeze()
53
 
54
def process_image(input_image, threshold=0.5):
    """Generate a saliency overlay for *input_image* and report the share
    of total saliency mass lying above *threshold*.

    Parameters:
        input_image: PIL image (or array-like) to analyse.
        threshold: grayscale saliency cutoff; pixels above it count toward
            the reported ratio.

    Returns:
        (blended_image, summary_text): float RGB overlay in [0, 1] and the
        above-threshold/total ratio formatted to 4 decimal places.
    """
    image_arr = np.array(input_image, dtype=np.float32)
    original_shape = image_arr.shape[:2]
    target_shape = get_target_shape(original_shape)

    input_tensor, pad_v, pad_h = preprocess_input(image_arr, target_shape)
    saliency_gray = postprocess_output(
        model(input_tensor)["output"], pad_v, pad_h, original_shape
    )

    # Share of saliency mass concentrated above the threshold.
    total = np.sum(saliency_gray)
    strong = np.sum(saliency_gray[saliency_gray > threshold])
    ratio = strong / total if total > 0 else 0.0

    # Inferno-colored heatmap blended over the (normalized) input image.
    overlay = plt.cm.inferno(saliency_gray)[..., :3]
    alpha = 0.9
    blended = alpha * overlay + (1 - alpha) * image_arr / 255

    return blended, f"{ratio:.4f}"
74
 
75
def predict_single(image):
    """Gradio handler for the single-image tab: delegate to process_image."""
    return process_image(image)
@@ -84,15 +94,18 @@ with gr.Blocks(title="MSI-Net Saliency App") as demo:
84
  gr.Markdown("## MSI-Net Saliency Map Viewer")
85
  with gr.Tabs():
86
  with gr.Tab("Single Image"):
87
- gr.Markdown("### Upload an image to see its saliency map and total grayscale saliency value.")
88
  with gr.Row():
89
- input_image_single = gr.Image(type="pil", label="Input Image")
90
  with gr.Row():
91
  output_image_single = gr.Image(type="numpy", label="Saliency Map")
92
- output_text_single = gr.Textbox(label="Grayscale Sum")
93
  submit_single = gr.Button("Generate Saliency")
94
- submit_single.click(fn=predict_single, inputs=input_image_single, outputs=[output_image_single, output_text_single])
95
-
 
 
 
96
  with gr.Tab("Compare Two Images"):
97
  gr.Markdown("### Upload two images to compare their saliency maps and grayscale saliency values.")
98
  with gr.Row():
 
51
  output_tensor = tf.image.resize(output_tensor, original_shape)
52
  return output_tensor.numpy().squeeze()
53
 
54
def process_image_with_bbox(input_image, bbox, threshold=0.0):
    """Compute a saliency overlay for *input_image* and the fraction of
    total saliency falling inside *bbox*.

    Parameters:
        input_image: PIL image (or array-like) to analyse.
        bbox: [x_min, y_min, x_max, y_max] in pixel coordinates of the
            input image, or None to skip region measurement.
        threshold: unused; retained for backward compatibility with
            callers that pass it.

    Returns:
        (blended_image, summary): float RGB overlay in [0, 1] and the
        bbox-to-total saliency ratio formatted to 4 decimal places
        ("0.0000" when bbox is None or total saliency is zero).
    """
    input_image_np = np.array(input_image, dtype=np.float32)
    original_shape = input_image_np.shape[:2]
    target_shape = get_target_shape(original_shape)

    input_tensor, vertical_padding, horizontal_padding = preprocess_input(input_image_np, target_shape)
    output_tensor = model(input_tensor)["output"]
    saliency_gray = postprocess_output(output_tensor, vertical_padding, horizontal_padding, original_shape)

    # Total saliency mass over the whole image.
    total_saliency = np.sum(saliency_gray)

    # Bounding box: bbox = [x_min, y_min, x_max, y_max] in image pixel coordinates
    if bbox is not None:
        height, width = original_shape
        x_min, y_min, x_max, y_max = map(int, bbox)
        # Clamp to the image bounds: negative indices would otherwise wrap
        # around in NumPy slicing and silently measure the wrong region,
        # and oversized boxes would just be truncated inconsistently.
        x_min, x_max = max(0, x_min), min(width, x_max)
        y_min, y_max = max(0, y_min), min(height, y_max)
        saliency_crop = saliency_gray[y_min:y_max, x_min:x_max]
        bbox_sum = np.sum(saliency_crop)
        bbox_ratio = bbox_sum / total_saliency if total_saliency > 0 else 0.0
    else:
        bbox_sum = 0
        bbox_ratio = 0.0

    # Heatmap overlay: inferno colormap blended over the normalized input.
    saliency_rgb = plt.cm.inferno(saliency_gray)[..., :3]
    alpha = 0.9
    blended_image = alpha * saliency_rgb + (1 - alpha) * input_image_np / 255

    summary = f"{bbox_ratio:.4f}"

    return blended_image, summary
84
 
85
  def predict_single(image):
86
  return process_image(image)
 
94
  gr.Markdown("## MSI-Net Saliency Map Viewer")
95
  with gr.Tabs():
96
  with gr.Tab("Single Image"):
97
+ gr.Markdown("### Upload an image and draw a bounding box to measure saliency inside it.")
98
  with gr.Row():
99
+ input_image_single = gr.Image(type="pil", tool="select", label="Input Image with Bounding Box")
100
  with gr.Row():
101
  output_image_single = gr.Image(type="numpy", label="Saliency Map")
102
+ output_text_single = gr.Textbox(label="Saliency Stats")
103
  submit_single = gr.Button("Generate Saliency")
104
+ submit_single.click(
105
+ fn=process_image_with_bbox,
106
+ inputs=[input_image_single, input_image_single.select_region],
107
+ outputs=[output_image_single, output_text_single],
108
+ )
109
  with gr.Tab("Compare Two Images"):
110
  gr.Markdown("### Upload two images to compare their saliency maps and grayscale saliency values.")
111
  with gr.Row():