File size: 3,493 Bytes
2f3546c
 
 
62cf953
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe67183
2f3546c
fe67183
62cf953
fe67183
 
62cf953
fe67183
 
 
 
2f3546c
fe67183
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import streamlit as st
from streamlit_back_camera_input import back_camera_input

import matplotlib.pyplot as plt
import tensorflow as tf

@st.cache_resource
def _load_saved_model(model_dir):
    """Load the TensorFlow SavedModel stored in *model_dir*.

    Decorated with ``st.cache_resource`` because Streamlit re-executes this
    whole script on every widget interaction; without the cache the model
    would be reloaded from disk on each rerun. The cache also holds a
    reference to the loaded object, so it stays alive while its serving
    signature (taken below) is in use.
    """
    return tf.saved_model.load(model_dir)


# Callable serving signature of the model; it is invoked with a batched image
# tensor and its result is indexed by the "output" key in compute_saliency.
loaded_model = _load_saved_model("model/").signatures["serving_default"]

def get_target_shape(original_shape):
    """Pick the model input size whose aspect ratio best matches the image.

    Args:
        original_shape: Sequence whose first two items are (height, width).

    Returns:
        One of (320, 320), (240, 320) or (320, 240) as (height, width).
        When two candidates are equally close, the earlier one in the order
        square, landscape, portrait wins.
    """
    ratio = original_shape[0] / original_shape[1]

    # (reference height/width ratio, resulting target shape), in priority order.
    candidates = (
        (1.0, (320, 320)),
        (240 / 320, (240, 320)),
        (320 / 240, (320, 240)),
    )

    # min() returns the first of several equal minima, preserving priority.
    _, shape = min(candidates, key=lambda candidate: abs(ratio - candidate[0]))
    return shape


def preprocess_input(input_image, target_shape):
    """Letterbox an image into a batch tensor of exactly ``target_shape``.

    The image gets a leading batch axis, is resized to fit inside
    ``target_shape`` while keeping its aspect ratio, and is then zero-padded
    (split as evenly as possible, extra pixel on the bottom/right) up to the
    full target size.

    Args:
        input_image: Image array; presumably (height, width, channels) —
            confirm against callers.
        target_shape: (height, width) of the padded output.

    Returns:
        Tuple ``(tensor, [top, bottom], [left, right])`` where the two lists
        are the padding amounts, needed later to crop the model output.
    """
    batch = tf.expand_dims(input_image, axis=0)
    fitted = tf.image.resize(batch, target_shape, preserve_aspect_ratio=True)

    # Shortfall between the fitted image and the requested size.
    missing_rows = target_shape[0] - fitted.shape[1]
    missing_cols = target_shape[1] - fitted.shape[2]

    top = missing_rows // 2
    left = missing_cols // 2
    rows_pad = [top, missing_rows - top]
    cols_pad = [left, missing_cols - left]

    padded = tf.pad(fitted, [[0, 0], rows_pad, cols_pad, [0, 0]])

    return padded, rows_pad, cols_pad


def postprocess_output(
    output_tensor, vertical_padding, horizontal_padding, original_shape
):
    """Undo the letterbox padding and turn the prediction into an RGB heat map.

    Args:
        output_tensor: Rank-4 model output (indexed as batch, rows, cols,
            channels below).
        vertical_padding: [top, bottom] rows of padding added during
            preprocessing.
        horizontal_padding: [left, right] columns of padding added during
            preprocessing.
        original_shape: (height, width) to resize the cropped map back to.

    Returns:
        NumPy array with the singleton axes squeezed out, mapped through the
        matplotlib "inferno" colormap with the alpha channel dropped.
    """
    top, bottom = vertical_padding[0], vertical_padding[1]
    left, right = horizontal_padding[0], horizontal_padding[1]
    rows = output_tensor.shape[1]
    cols = output_tensor.shape[2]

    # Remove the padded border so only the real image region remains.
    cropped = output_tensor[:, top : rows - bottom, left : cols - right, :]
    restored = tf.image.resize(cropped, original_shape)

    saliency = restored.numpy().squeeze()
    colored = plt.cm.inferno(saliency)
    return colored[..., :3]


def compute_saliency(input_image, alpha=0.65):
    """Run the saliency model on an image and blend the heat map over it.

    Args:
        input_image: Decoded image array, or None. The ``/ 255`` below
            assumes 8-bit pixel values (0-255) — confirm against callers.
        alpha: Weight of the heat map in the blend; the image gets
            ``1 - alpha``.

    Returns:
        The blended image array, or None when ``input_image`` is None.
    """
    if input_image is None:
        return None

    original_shape = input_image.shape[:2]
    padded, pad_rows, pad_cols = preprocess_input(
        input_image, get_target_shape(original_shape)
    )

    prediction = loaded_model(padded)["output"]
    heat_map = postprocess_output(prediction, pad_rows, pad_cols, original_shape)

    return alpha * heat_map + (1 - alpha) * input_image / 255


def _decode_image(image_file):
    """Decode an uploaded/captured image file object into an RGB NumPy array."""
    # channels=3 drops a PNG alpha channel and expands grayscale, so the
    # pixels line up with the 3-channel heat map they are blended with.
    decoded = tf.image.decode_image(
        image_file.getvalue(), channels=3, expand_animations=False
    )
    return decoded.numpy()


st.title("Visual Saliency Prediction")

col1, col2, col3 = st.columns([1, 1, 1])

with col1:
    input_image = st.file_uploader("Upload Input Image", type=["jpg", "jpeg", "png"])

with col2:
    # The camera widget is an input source: it returns the captured image
    # (None before a capture; presumably a BytesIO — verify with the package).
    camera_image = back_camera_input()
    if camera_image is not None:
        st.image(camera_image)

with col3:
    btn = st.button("Compute")

# Dedicated display slot for the result, filled after computation.
output_image = st.empty()

if btn:
    # Prefer an uploaded file; fall back to the camera capture.
    source = input_image if input_image is not None else camera_image
    if source is not None:
        # Perform computation on decoded pixels, not the raw file object.
        saliency_map = compute_saliency(_decode_image(source))

        # Display output
        output_image.image(saliency_map, caption="Saliency Map", use_column_width=True)
    else:
        st.warning("Please upload an image.")