File size: 5,997 Bytes
03e7460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
from datetime import datetime
from pathlib import Path
import torch
import folium
import streamlit as st
from loguru import logger
from tqdm import tqdm
from streamlit_folium import st_folium
from transformers import SegformerForSemanticSegmentation
from lib.folium import (
    get_clean_rendering_container,
    create_map,
    process_raster_and_overlays,
)
import streamlit.components.v1 as components

# Page configs (page icon: grapes emoji, for the vineyard demo)
st.set_page_config(page_title="GrowSeg Demo", page_icon="🍇", layout="wide")

# BUGFIX (https://discuss.streamlit.io/t/message-error-about-torch/90886/6):
# workaround from the linked thread for the torch-related error Streamlit raises;
# it empties the ``__path__`` attribute of ``torch.classes``.
torch.classes.__path__ = []

# Interoperability with tqdm (https://loguru.readthedocs.io/en/stable/resources/recipes.html#interoperability-with-tqdm-iterations):
# replace loguru's default sink with one that writes through ``tqdm.write`` so
# log lines do not break active tqdm progress bars.
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, format="<green>{message}</green>")

@st.cache_resource
def load_model(hf_path='links-ads/gaia-growseg'):
    """Load the GAIA GrowSeg SegFormer model, move it to a device and set eval mode.

    Wrapped in ``st.cache_resource`` so the loaded model is cached as a
    Streamlit resource instead of being reloaded on every script rerun.

    Args:
        hf_path: Model identifier passed to
            ``SegformerForSemanticSegmentation.from_pretrained``.

    Returns:
        The model on CUDA when ``torch.cuda.is_available()``, otherwise on CPU,
        switched to eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Log only once the device is known so the message can report it.
    logger.info(f'Loading GAIA GrowSeg ({hf_path}) on {device}...')
    model = SegformerForSemanticSegmentation.from_pretrained(
        hf_path,
        num_labels=1,
        num_channels=3,
        # NOTE(review): with num_labels=1 the single output channel would be
        # index 0, yet 'vine' is mapped to id 1 — confirm against the checkpoint config.
        id2label={1: 'vine'},
        label2id={'vine': 1},
        # Optional access token read from the environment (None when unset).
        token=os.getenv('hf_read_access_token'),
    )
    return model.to(device).eval()


# Load GAIA GrowSeg model
model = load_model()


def change_key():
    """Store a fresh timestamp string under ``st.session_state["key_map"]``.

    The ``st_folium`` call reads this value as its widget ``key``; giving it a
    new value is the workaround used to force the map to recenter.
    """
    fresh_key = datetime.now().isoformat(sep=" ")
    st.session_state["key_map"] = fresh_key


# Create selection menu
container_predictions = st.container(border=True)
with container_predictions:
    col1, col2 = st.columns([0.3, 0.7])
    with col1:
        # At most one of these is set by the selection below.
        precomputed_map_path = None
        raster_path = None
        raster_selection = st.selectbox(
            "Select an example or your own raster...",
            options=[
                "Italy",
                "Portugal",
                "Spain",
                "Upload file...",
            ],
            key="raster_selection_block",
            index=None,
            placeholder="Choose an example or upload your own raster",
        )
        if raster_selection in ("Italy", "Spain"):
            # These precomputed maps are not served yet (WebSocket payload limit).
            # TODO GEOSERVER: serve data/italy_2022-06-13_cropped.html and
            # data/spain_2022-07-29_cropped.html once available.
            st.markdown("At this stage, only Portugal is available due to the WebSocket payload limit.")
        elif raster_selection == "Portugal":
            precomputed_map_path = "data/portugal_2023-08-01.html"
        elif raster_selection == "Upload file...":
            uploaded_file = st.file_uploader(
                "Upload a raster file",
                type=["tif"],
                key="uploaded_file_block",
            )
            if uploaded_file is not None:
                # Keep only the final path component of the client-supplied name.
                fn = Path(uploaded_file.name).name
                logger.info(f"Received uploaded raster: {fn}")
                # Create the upload directory on first use; open() would otherwise
                # raise FileNotFoundError when "temp" does not exist.
                os.makedirs("temp", exist_ok=True)
                raster_path = os.path.join("temp", fn)
                with open(raster_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())

    is_raster_path_selected = raster_path is not None
    is_precomputed_map_selected = precomputed_map_path is not None

    with col2:
        with st.container():
            st.write("######")
            with st.expander("More info on the model"):
                st.write("""
                    Under the hood, this model is a SegFormer-b5, trained on
                    UAV-acquired vineyard orthoimages and their ground-truth
                    delineation masks. Paper will be available soon. Stay tuned!
                """)

    if not is_precomputed_map_selected and is_raster_path_selected:
        progress_bar = st.progress(0, text="Begin processing...")
        # Run the model on the uploaded raster and build the map overlays.
        overlays = process_raster_and_overlays(raster_path, model, _progress_bar=progress_bar)


# Single-element placeholder that holds the map form.
container = st.empty()

# Draw map
interactive_map = create_map()

if is_raster_path_selected:
    # Add overlays to map
    for overlay in overlays:
        overlay.add_to(interactive_map)

with container.form(key="form1"):

    if is_precomputed_map_selected:
        # Precomputed maps are stored as HTML files and embedded as-is.
        # components.html returns None, so its result is not kept.
        with open(precomputed_map_path, 'r', encoding='utf-8') as f:
            html_content = f.read()
        components.html(html_content, height=500)

    else:
        if is_raster_path_selected and overlays:
            # Center map on the first overlay
            interactive_map.fit_bounds(overlays[0].get_bounds())
        else:
            # Center map on Europe (also the fallback when no overlay was produced)
            interactive_map.fit_bounds([[35, -10], [60, 40]])

        # Add Layer Control (first remove any existing one so it is not duplicated)
        for key, child in list(interactive_map._children.items()):
            if isinstance(child, folium.map.LayerControl):
                del interactive_map._children[key]
        folium.LayerControl().add_to(interactive_map)

        # Folium Map component
        output_map = st_folium(
            interactive_map,
            width=None,
            height=500,
            returned_objects=["all_drawings"],
            # Workaround to force the map to recenter: change_key() stores a new
            # value here, so the component is recreated with the bounds set above.
            key=st.session_state.get("key_map", "key_map"),
        )

    # Recenter map: the callback swaps the map key before the rerun.
    submit = st.form_submit_button("Recenter map", on_click=change_key)