MateuszLis commited on
Commit
bbd3890
·
verified ·
1 Parent(s): f32fc1c

Update saliency_gradio.py

Browse files
Files changed (1) hide show
  1. saliency_gradio.py +34 -39
saliency_gradio.py CHANGED
@@ -10,13 +10,10 @@ hf_dir = snapshot_download(repo_id="alexanderkroner/MSI-Net")
10
 
11
  def get_target_shape(original_shape):
12
  original_aspect_ratio = original_shape[0] / original_shape[1]
13
-
14
  square_mode = abs(original_aspect_ratio - 1.0)
15
  landscape_mode = abs(original_aspect_ratio - 240 / 320)
16
  portrait_mode = abs(original_aspect_ratio - 320 / 240)
17
-
18
  best_mode = min(square_mode, landscape_mode, portrait_mode)
19
-
20
  if best_mode == square_mode:
21
  return (320, 320)
22
  elif best_mode == landscape_mode:
@@ -26,19 +23,13 @@ def get_target_shape(original_shape):
26
 
27
  def preprocess_input(input_image, target_shape):
28
  input_tensor = tf.expand_dims(input_image, axis=0)
29
-
30
- input_tensor = tf.image.resize(
31
- input_tensor, target_shape, preserve_aspect_ratio=True
32
- )
33
-
34
  vertical_padding = target_shape[0] - input_tensor.shape[1]
35
  horizontal_padding = target_shape[1] - input_tensor.shape[2]
36
-
37
  vertical_padding_1 = vertical_padding // 2
38
  vertical_padding_2 = vertical_padding - vertical_padding_1
39
  horizontal_padding_1 = horizontal_padding // 2
40
  horizontal_padding_2 = horizontal_padding - horizontal_padding_1
41
-
42
  input_tensor = tf.pad(
43
  input_tensor,
44
  [
@@ -48,12 +39,7 @@ def preprocess_input(input_image, target_shape):
48
  [0, 0],
49
  ],
50
  )
51
-
52
- return (
53
- input_tensor,
54
- [vertical_padding_1, vertical_padding_2],
55
- [horizontal_padding_1, horizontal_padding_2],
56
- )
57
 
58
  def postprocess_output(output_tensor, vertical_padding, horizontal_padding, original_shape):
59
  output_tensor = output_tensor[
@@ -62,45 +48,54 @@ def postprocess_output(output_tensor, vertical_padding, horizontal_padding, orig
62
  horizontal_padding[0] : output_tensor.shape[2] - horizontal_padding[1],
63
  :,
64
  ]
65
-
66
  output_tensor = tf.image.resize(output_tensor, original_shape)
67
- return output_tensor.numpy().squeeze() # Return grayscale map
68
 
69
  def process_image(input_image):
70
  input_image = np.array(input_image, dtype=np.float32)
71
  original_shape = input_image.shape[:2]
72
  target_shape = get_target_shape(original_shape)
73
-
74
  input_tensor, vertical_padding, horizontal_padding = preprocess_input(input_image, target_shape)
75
  output_tensor = model(input_tensor)["output"]
76
  saliency_gray = postprocess_output(output_tensor, vertical_padding, horizontal_padding, original_shape)
77
  total_saliency = np.sum(saliency_gray)
78
-
79
  saliency_rgb = plt.cm.inferno(saliency_gray)[..., :3]
80
  alpha = 0.9
81
  blended_image = alpha * saliency_rgb + (1 - alpha) * input_image / 255
82
-
83
  return blended_image, f"Total grayscale saliency: {total_saliency:.2f}"
84
 
85
- def predict_two_images(image1, image2):
 
 
 
86
  result1_img, result1_val = process_image(image1)
87
  result2_img, result2_val = process_image(image2)
88
  return result1_img, result1_val, result2_img, result2_val
89
 
90
- iface = gr.Interface(
91
- fn=predict_two_images,
92
- inputs=[
93
- gr.Image(type="pil", label="Input Image 1"),
94
- gr.Image(type="pil", label="Input Image 2"),
95
- ],
96
- outputs=[
97
- gr.Image(type="numpy", label="Saliency Map 1"),
98
- gr.Textbox(label="Grayscale Sum 1"),
99
- gr.Image(type="numpy", label="Saliency Map 2"),
100
- gr.Textbox(label="Grayscale Sum 2"),
101
- ],
102
- title="MSI-Net Saliency Maps for Two Images",
103
- description="Upload two images to compare their saliency maps and total saliency values.",
104
- )
105
-
106
- iface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def get_target_shape(original_shape):
12
  original_aspect_ratio = original_shape[0] / original_shape[1]
 
13
  square_mode = abs(original_aspect_ratio - 1.0)
14
  landscape_mode = abs(original_aspect_ratio - 240 / 320)
15
  portrait_mode = abs(original_aspect_ratio - 320 / 240)
 
16
  best_mode = min(square_mode, landscape_mode, portrait_mode)
 
17
  if best_mode == square_mode:
18
  return (320, 320)
19
  elif best_mode == landscape_mode:
 
23
 
24
  def preprocess_input(input_image, target_shape):
25
  input_tensor = tf.expand_dims(input_image, axis=0)
26
+ input_tensor = tf.image.resize(input_tensor, target_shape, preserve_aspect_ratio=True)
 
 
 
 
27
  vertical_padding = target_shape[0] - input_tensor.shape[1]
28
  horizontal_padding = target_shape[1] - input_tensor.shape[2]
 
29
  vertical_padding_1 = vertical_padding // 2
30
  vertical_padding_2 = vertical_padding - vertical_padding_1
31
  horizontal_padding_1 = horizontal_padding // 2
32
  horizontal_padding_2 = horizontal_padding - horizontal_padding_1
 
33
  input_tensor = tf.pad(
34
  input_tensor,
35
  [
 
39
  [0, 0],
40
  ],
41
  )
42
+ return input_tensor, [vertical_padding_1, vertical_padding_2], [horizontal_padding_1, horizontal_padding_2]
 
 
 
 
 
43
 
44
  def postprocess_output(output_tensor, vertical_padding, horizontal_padding, original_shape):
45
  output_tensor = output_tensor[
 
48
  horizontal_padding[0] : output_tensor.shape[2] - horizontal_padding[1],
49
  :,
50
  ]
 
51
  output_tensor = tf.image.resize(output_tensor, original_shape)
52
+ return output_tensor.numpy().squeeze()
53
 
54
  def process_image(input_image):
55
  input_image = np.array(input_image, dtype=np.float32)
56
  original_shape = input_image.shape[:2]
57
  target_shape = get_target_shape(original_shape)
 
58
  input_tensor, vertical_padding, horizontal_padding = preprocess_input(input_image, target_shape)
59
  output_tensor = model(input_tensor)["output"]
60
  saliency_gray = postprocess_output(output_tensor, vertical_padding, horizontal_padding, original_shape)
61
  total_saliency = np.sum(saliency_gray)
 
62
  saliency_rgb = plt.cm.inferno(saliency_gray)[..., :3]
63
  alpha = 0.9
64
  blended_image = alpha * saliency_rgb + (1 - alpha) * input_image / 255
 
65
  return blended_image, f"Total grayscale saliency: {total_saliency:.2f}"
66
 
67
+ def predict_single(image):
68
+ return process_image(image)
69
+
70
+ def predict_dual(image1, image2):
71
  result1_img, result1_val = process_image(image1)
72
  result2_img, result2_val = process_image(image2)
73
  return result1_img, result1_val, result2_img, result2_val
74
 
75
+ with gr.Blocks(title="MSI-Net Saliency App") as demo:
76
+ gr.Markdown("## MSI-Net Saliency Map Viewer")
77
+ with gr.Tabs():
78
+ with gr.Tab("Single Image"):
79
+ gr.Markdown("### Upload an image to see its saliency map and total grayscale saliency value.")
80
+ with gr.Row():
81
+ input_image_single = gr.Image(type="pil", label="Input Image")
82
+ with gr.Row():
83
+ output_image_single = gr.Image(type="numpy", label="Saliency Map")
84
+ output_text_single = gr.Textbox(label="Grayscale Sum")
85
+ submit_single = gr.Button("Generate Saliency")
86
+ submit_single.click(fn=predict_single, inputs=input_image_single, outputs=[output_image_single, output_text_single])
87
+
88
+ with gr.Tab("Compare Two Images"):
89
+ gr.Markdown("### Upload two images to compare their saliency maps and grayscale saliency values.")
90
+ with gr.Row():
91
+ input_image1 = gr.Image(type="pil", label="Image 1")
92
+ input_image2 = gr.Image(type="pil", label="Image 2")
93
+ with gr.Row():
94
+ output_image1 = gr.Image(type="numpy", label="Saliency Map 1")
95
+ output_text1 = gr.Textbox(label="Grayscale Sum 1")
96
+ output_image2 = gr.Image(type="numpy", label="Saliency Map 2")
97
+ output_text2 = gr.Textbox(label="Grayscale Sum 2")
98
+ submit_dual = gr.Button("Compare Saliency")
99
+ submit_dual.click(fn=predict_dual, inputs=[input_image1, input_image2], outputs=[output_image1, output_text1, output_image2, output_text2])
100
+
101
+ demo.launch(share=True)