File size: 9,759 Bytes
46818a5
03e7460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
#import branca
import folium
import streamlit as st
from loguru import logger
import rioxarray as rxr
import numpy as np
import xarray as xr
import torch
from .utils import compute_mask, compute_vndvi, compute_vdi
import os



@st.cache_resource
def create_map(location=(41.9099533, 12.3711879), zoom_start=5, crs=3857, max_zoom=23):
    """Create a folium map with OpenStreetMap tiles and an optional Esri.WorldImagery basemap.

    Args:
        location: Initial map centre as ``(latitude, longitude)``.
        zoom_start: Initial zoom level.
        crs: Map CRS, either an EPSG integer (``3857``) or a string such as ``"EPSG3857"``
            or ``"EPSG:3857"``. Only Web Mercator (EPSG:3857) is currently supported.
        max_zoom: Maximum zoom level allowed on the map.

    Returns:
        folium.Map whose default base layer is OpenStreetMap. Esri.WorldImagery is added as
        an alternative base layer (``overlay=False``, hidden initially), so it shows up as a
        radio button in any ``folium.LayerControl`` added to the map.

    Raises:
        ValueError: if *crs* does not denote EPSG:3857.

    Note:
        ``st.cache_resource`` hands back the *same* Map object for identical arguments.
        NOTE(review): if callers add layers to the returned map, they accumulate on the
        cached instance across reruns — confirm this is intended.
    """
    if isinstance(crs, int):
        crs = f"EPSG{crs}"
    # folium spells CRS names without a colon ("EPSG3857"); also accept "EPSG:3857" / "epsg3857".
    crs = str(crs).replace(":", "").upper()
    if crs != "EPSG3857":
        # Raised explicitly (rather than `assert`) so the check survives `python -O`.
        raise ValueError(f"Only EPSG:3857 supported for now. Got {crs}.")

    m = folium.Map(
        location=location,
        zoom_start=zoom_start,
        crs=crs,
        max_zoom=max_zoom,
        tiles="OpenStreetMap",  # Esri.WorldImagery
        attributionControl=False,
        prefer_canvas=True,
    )

    # Add Esri.WorldImagery as optional basemap (radio button)
    folium.TileLayer(
        tiles="Esri.WorldImagery",
        show=False,
        overlay=False,
        control=True,
    ).add_to(m)

    return m



def get_clean_rendering_container(app_state: str):
    """Return a container to render into, moved to the other placeholder when the app state changes.

    Two ``st.empty()`` placeholders ("a" and "b") are created on every call. The one handed
    out is remembered in ``st.session_state.slot_in_use`` and flips whenever *app_state*
    differs from the value recorded on the previous call (``st.session_state.previous_state``),
    so a new state renders into a different placeholder than the old one did.
    """
    session = st.session_state

    # Persist the active slot, defaulting to "a" the first time through.
    active = session.get("slot_in_use", "a")
    session.slot_in_use = active

    # With no recorded previous state the default equals app_state, so nothing flips.
    if app_state != session.get("previous_state", app_state):
        active = "b" if active == "a" else "a"
        session.slot_in_use = active

    session.previous_state = app_state

    # Both placeholders are created unconditionally and always in the same order;
    # only the active one is returned for rendering.
    placeholders = {key: st.empty() for key in ("a", "b")}
    return placeholders[active].container()



def create_image_overlay(raster_path_or_array, name="Raster", opacity=1.0, to_crs=4326, show=True):
    """Create a folium image overlay from a raster file path or an xarray.DataArray.

    Args:
        raster_path_or_array: Path (``str`` or ``os.PathLike``) to a raster readable by
            ``rioxarray.open_rasterio``, or an already-opened rioxarray DataArray with
            ``band``, ``y`` and ``x`` dimensions and a CRS.
        name: Layer name shown in folium's layer control.
        opacity: Overlay opacity.
        to_crs: EPSG code (int) the raster is reprojected to when its CRS differs. The
            overlay bounds are read in this CRS and passed to folium as lat/lon, so it is
            expected to be a geographic CRS such as 4326.
        show: Whether the layer is visible when the map loads.

    Returns:
        folium.raster_layers.ImageOverlay built from the (reprojected) pixel array.

    Raises:
        ValueError: if the raster carries no CRS.
    """
    import contextlib

    if isinstance(raster_path_or_array, (str, os.PathLike)):
        # Opened here, so closed here once the pixels have been read below.
        source = rxr.open_rasterio(raster_path_or_array)
    else:
        # Caller-owned array: use as-is and leave its lifetime to the caller.
        source = contextlib.nullcontext(raster_path_or_array)

    with source as r:
        if r.rio.crs is None:
            raise ValueError(f"Raster {name!r} has no CRS; cannot place it on the map.")

        nodata = r.rio.nodata or 0
        if r.rio.crs.to_epsg() != to_crs:
            # Explicit nodata instead of rioxarray's dtype-based default (255 for uint8).
            r = r.rio.reproject(to_crs, nodata=nodata)
        # folium expects an image laid out as (rows, cols, channels).
        r = r.transpose("y", "x", "band")
        left, bottom, right, top = r.rio.bounds()

        # Built inside the `with` so the pixel data is read before the file is closed.
        overlay = folium.raster_layers.ImageOverlay(
            image=r.to_numpy(),
            name=name,
            bounds=[[bottom, left], [top, right]],    # folium format: ((bottom, left), (top, right))
            opacity=opacity,
            interactive=True,
            cross_origin=False,
            zindex=1,
            show=show,
        )

    return overlay



def _derived_path(raster_path, suffix):
    """Return *raster_path* with *suffix* inserted just before its file extension.

    ``"site.tif"`` -> ``"site_mask.tif"``. Only the final extension is touched, so an
    upper-case ``".TIF"`` or a ``".tif"`` inside a directory name is handled correctly.
    """
    root, ext = os.path.splitext(raster_path)
    return f"{root}{suffix}{ext}"


def _report(progress_bar, percent, message, progress_text=None):
    """Log *message* and, if a progress bar is given, set it to *percent* (0-100).

    *progress_text* is shown on the bar instead of *message* when provided.
    """
    # depth=1 makes loguru attribute the record to the caller, not to this helper.
    logger.opt(depth=1).info(message)
    if progress_bar:
        progress_bar.progress(percent, text=message if progress_text is None else progress_text)


def _georeference_like(array, reference):
    """Wrap a band-first ``(band, y, x)`` array as a DataArray on *reference*'s grid.

    The x/y coordinates, CRS and affine transform are copied from *reference*. Bands are
    numbered 1..N from *array* itself: the derived products need not have the same band
    count as the source orthoimage (e.g. a 4-band RGBA mask over a 3-band RGB image).
    """
    georef = xr.DataArray(
        array,
        dims=('band', 'y', 'x'),
        coords={'x': reference.x, 'y': reference.y, 'band': np.arange(1, array.shape[0] + 1)},
    )
    georef.rio.write_crs(reference.rio.crs, inplace=True)  # Copy CRS
    georef.rio.write_transform(reference.rio.transform(), inplace=True)  # Copy affine transform
    return georef


@st.cache_resource
def process_raster_and_overlays(
        raster_path: str,
        _model: torch.nn.Module,
        patch_size=512,
        stride=256,
        scaling_factor=None,
        rotate=False,
        batch_size=16,
        window_size=360,
        dilate_rows=False,
        _progress_bar=None,
        ):
    """Compute (or load) the vine mask, vNDVI and VDI rasters for an orthoimage and build overlays.

    Each derived product is looked up as a sibling of *raster_path* with the same extension
    (``<name>_mask``, ``<name>_vndvi_rows``, ``<name>_vndvi_interrows``, ``<name>_vdi``) and
    loaded when present; otherwise it is computed. This function only reads such files; it
    never writes them. Every raster is reprojected to EPSG:4326 before overlays are built.

    Args:
        raster_path: Path to the orthoimage, readable by ``rioxarray.open_rasterio``.
        _model: Model passed to ``compute_mask``. The leading underscore makes
            ``st.cache_resource`` skip it when hashing arguments.
        patch_size, stride, scaling_factor, rotate, batch_size: Forwarded to ``compute_mask``.
        window_size: Forwarded to ``compute_vndvi`` and ``compute_vdi``.
        dilate_rows: Forwarded to ``compute_vndvi``.
        _progress_bar: Optional Streamlit progress bar updated as work proceeds (not hashed).

    Returns:
        List of five ``folium.raster_layers.ImageOverlay`` — orthoimage (shown), mask,
        vNDVI rows, vNDVI interrows and VDI (the last four hidden initially).
    """
    mask_path = _derived_path(raster_path, '_mask')
    vndvi_rows_path = _derived_path(raster_path, '_vndvi_rows')
    vndvi_interrows_path = _derived_path(raster_path, '_vndvi_interrows')
    vdi_path = _derived_path(raster_path, '_vdi')

    have_mask = os.path.exists(mask_path)
    have_vndvi = os.path.exists(vndvi_rows_path) and os.path.exists(vndvi_interrows_path)
    have_vdi = os.path.exists(vdi_path)
    # The cached mask file is an RGBA image, not the HxW label array compute_vndvi and
    # compute_vdi take, so it is only reused when nothing downstream must be computed.
    use_cached_mask = have_mask and have_vndvi and have_vdi
    if use_cached_mask:
        logger.info(f"Found mask at {mask_path!r}, vNDVI at {vndvi_rows_path!r} and {vndvi_interrows_path!r}, and VDI at {vdi_path!r}. Loading...")
    elif have_mask:
        logger.warning(f"Found mask at {mask_path!r} but not every vNDVI/VDI raster; recomputing the mask.")

    # Read raster
    _report(_progress_bar, 0, f'Reading raster image {raster_path!r}...')
    raster = rxr.open_rasterio(raster_path)

    # Mask
    _report(_progress_bar, 10, '### Computing mask...')
    if use_cached_mask:
        mask_raster = rxr.open_rasterio(mask_path)    # mask is RGBA (red for vine)
    else:
        # Read the pixels once; the mask/vNDVI/VDI computations below all reuse this array.
        pixels = raster.to_numpy()
        mask = compute_mask(
            pixels,
            _model,
            patch_size=patch_size,
            stride=stride,
            scaling_factor=scaling_factor,
            rotate=rotate,
            batch_size=batch_size,
        )   # HxW uint8 array with 0=background, 255=vine, 1=nodata

        # Render as RGBA: the red channel carries the mask, alpha hides nodata pixels.
        alpha = ((mask != 1) * 255).astype(np.uint8)
        mask_colored = np.stack([mask, np.zeros_like(mask), np.zeros_like(mask), alpha], axis=0)

        _report(_progress_bar, 30, 'Georeferencing mask...')
        mask_raster = _georeference_like(mask_colored, raster)

    # vNDVI
    _report(_progress_bar, 35, '### Computing vNDVI...')
    if have_vndvi:
        vndvi_rows_raster = rxr.open_rasterio(vndvi_rows_path)    # vNDVI is RGBA
        vndvi_interrows_raster = rxr.open_rasterio(vndvi_interrows_path)  # vNDVI is RGBA
    else:
        # Some product is missing, so use_cached_mask is False and `pixels`/`mask` exist.
        vndvi_rows, vndvi_interrows = compute_vndvi(
            pixels,
            mask,
            dilate_rows=dilate_rows,
            window_size=window_size,
        )   # channel-last (HxWxC, RGBA) arrays, moved to band-first below

        _report(_progress_bar, 55, 'Georeferencing vNDVI...')
        vndvi_rows_raster = _georeference_like(vndvi_rows.transpose(2, 0, 1), raster)
        vndvi_interrows_raster = _georeference_like(vndvi_interrows.transpose(2, 0, 1), raster)

    # VDI
    _report(_progress_bar, 60, '### Computing VDI...')
    if have_vdi:
        vdi_raster = rxr.open_rasterio(vdi_path)    # VDI is RGBA
    else:
        # Some product is missing, so use_cached_mask is False and `pixels`/`mask` exist.
        vdi = compute_vdi(pixels, mask, window_size=window_size)    # channel-last (RGBA)

        _report(_progress_bar, 80, 'Georeferencing VDI...')
        vdi_raster = _georeference_like(vdi.transpose(2, 0, 1), raster)

    # Reproject everything to EPSG:4326 (same order as the overlay specs below).
    rasters = [raster, mask_raster, vndvi_rows_raster, vndvi_interrows_raster, vdi_raster]
    if raster.rio.crs.to_epsg() != 4326:
        _report(_progress_bar, 82, "Reprojecting rasters to EPSG:4326 with NODATA value 0...")
        # Explicit nodata=0 instead of rioxarray's dtype-based default (255 for uint8).
        rasters = [r.rio.reproject("EPSG:4326", nodata=0) for r in rasters]

    # Overlays: (layer name, visible on load, progress %, log message, progress-bar text)
    overlay_specs = [
        ("Orthoimage", True, 85, 'Creating RGB raster overlay...', 'Creating overlays: drone image...'),
        ("Mask", False, 88, 'Creating mask overlay...', 'Creating overlays: mask...'),
        ("vNDVI Rows", False, 91, 'Creating vNDVI rows overlay...', 'Creating overlays: vNDVI (rows)...'),
        ("vNDVI Interrows", False, 94, 'Creating vNDVI interrows overlay...', 'Creating overlays: vNDVI (interrows)...'),
        ("VDI", False, 97, 'Creating VDI overlay...', 'Creating overlays: VDI...'),
    ]
    overlays = []
    for layer, (name, show, percent, message, bar_text) in zip(rasters, overlay_specs):
        _report(_progress_bar, percent, message, bar_text)
        overlays.append(create_image_overlay(layer, name=name, opacity=1.0, show=show))

    _report(_progress_bar, 100, 'Done!')

    return overlays