Spaces:
Starting
Starting
#import branca | |
import folium | |
import streamlit as st | |
from loguru import logger | |
import rioxarray as rxr | |
import numpy as np | |
import xarray as xr | |
import torch | |
from .utils import compute_mask, compute_vndvi, compute_vdi | |
import os | |
def create_map(location=(41.9099533, 12.3711879), zoom_start=5, crs=3857, max_zoom=23):
    """Create a folium map with OpenStreetMap tiles and an optional Esri.WorldImagery basemap.

    Args:
        location: ``(latitude, longitude)`` of the initial map centre. An immutable
            tuple is used as the default so no list object is shared between calls;
            any 2-element sequence is accepted, as before.
        zoom_start: Initial zoom level.
        crs: Map CRS, either an EPSG code (``3857``) or a folium CRS name
            (``"EPSG3857"``; the colon spelling ``"EPSG:3857"`` is accepted too).
            Only EPSG:3857 is currently supported.
        max_zoom: Maximum zoom level of the map.

    Returns:
        folium.Map: map whose default base layer is OpenStreetMap, with
        Esri.WorldImagery registered as a second, initially hidden base layer.

    Raises:
        AssertionError: if ``crs`` does not resolve to EPSG:3857.
    """
    if isinstance(crs, int):
        crs = f"EPSG{crs}"
    # folium names CRSs without a colon ("EPSG3857"); normalize the common
    # "EPSG:3857" spelling so it is not rejected by the check below.
    folium_crs = crs.replace(":", "") if isinstance(crs, str) else crs
    assert folium_crs in ["EPSG3857"], f"Only EPSG:3857 supported for now. Got {crs}."
    m = folium.Map(
        location=location,
        zoom_start=zoom_start,
        crs=folium_crs,
        max_zoom=max_zoom,
        tiles="OpenStreetMap",  # default base layer; Esri.WorldImagery is added below
        attributionControl=False,
        prefer_canvas=True,
    )
    # Esri.WorldImagery as an alternative base layer (overlay=False), hidden at
    # start; it appears as a radio button if a layer control is added to the map.
    folium.TileLayer(
        tiles="Esri.WorldImagery",
        show=False,
        overlay=False,
        control=True,
    ).add_to(m)
    return m
def get_clean_rendering_container(app_state: str):
    """Return a Streamlit container that switches placeholder whenever ``app_state`` changes.

    Two ``st.empty()`` placeholders, keyed "a" and "b", are created on every
    call, and the one named by ``st.session_state.slot_in_use`` (default "a")
    is returned as a container. When ``app_state`` differs from the value saved
    in ``st.session_state.previous_state`` by the previous call, the active key
    is flipped first, so the new state renders into the other placeholder.
    """
    session = st.session_state
    active_slot = session.get("slot_in_use", "a")
    session.slot_in_use = active_slot

    if app_state != session.get("previous_state", app_state):
        # State changed since the last call: move to the other placeholder.
        active_slot = "b" if active_slot == "a" else "a"
        session.slot_in_use = active_slot

    session.previous_state = app_state

    # Both placeholders are created on every call, always "a" then "b"; only the
    # active one receives content. (Presumably this keeps the page's element
    # layout stable across the flip — confirm before simplifying.)
    placeholders = {}
    placeholders["a"] = st.empty()
    placeholders["b"] = st.empty()
    return placeholders[active_slot].container()
def create_image_overlay(raster_path_or_array, name="Raster", opacity=1.0, to_crs=4326, show=True):
    """Create a folium image overlay from a raster file path or an xarray.DataArray.

    Args:
        raster_path_or_array: Path to a raster readable by ``rioxarray.open_rasterio``
            (``str`` or any ``os.PathLike``), or an already-opened DataArray with a
            ``band`` dimension and ``rio`` georeferencing.
        name: Layer name used by folium.
        opacity: Overlay opacity.
        to_crs: EPSG code the raster is reprojected to (when its CRS differs) before
            its bounds are read and handed to folium.
        show: Whether the layer is visible when the map first loads.

    Returns:
        folium.raster_layers.ImageOverlay built from the raster's pixels laid out
        as (y, x, band) and bounded by the raster's extent in ``to_crs``.
    """
    if isinstance(raster_path_or_array, (str, os.PathLike)):
        # Load the pixels and close the file right away rather than keeping the
        # handle open until the DataArray is garbage-collected.
        with rxr.open_rasterio(raster_path_or_array) as opened:
            r = opened.load()
    else:
        r = raster_path_or_array
    # Pixels outside the source footprint after reprojection are filled with the
    # raster's own nodata value, or 0 when it declares none.
    nodata = r.rio.nodata or 0
    if r.rio.crs.to_epsg() != to_crs:
        r = r.rio.reproject(to_crs, nodata=nodata)
    # folium expects image arrays as (rows, cols, channels).
    r = r.transpose("y", "x", "band")
    left, bottom, right, top = r.rio.bounds()
    return folium.raster_layers.ImageOverlay(
        image=r.to_numpy(),
        name=name,
        bounds=[[bottom, left], [top, right]],  # folium format: [[south, west], [north, east]]
        opacity=opacity,
        interactive=True,
        cross_origin=False,
        zindex=1,
        show=show,
    )
def _cached_output_path(raster_path, suffix):
    """Return the path of a product derived from ``raster_path``, stored next to it.

    ``suffix`` is inserted between the stem and the final extension, e.g.
    ``("field.tif", "_mask") -> "field_mask.tif"``. Only the final extension is
    considered, so a ".tif" elsewhere in the path (e.g. a directory name) is left
    alone, an upper-case ".TIF" is handled, and an extension-less path still
    yields a path distinct from the input raster.
    """
    root, ext = os.path.splitext(raster_path)
    return f"{root}{suffix}{ext}"


def _georeference_like(data, reference):
    """Wrap a (band, y, x) array in a DataArray on ``reference``'s pixel grid.

    The x/y coordinates, CRS and affine transform are copied from ``reference``.
    Bands are numbered 1..N from ``data``'s first axis, so ``data`` may have a
    different band count than ``reference`` (e.g. a 4-band RGBA product derived
    from a 3-band image).
    """
    georeferenced = xr.DataArray(
        data,
        dims=("band", "y", "x"),
        coords={
            "x": reference.x,
            "y": reference.y,
            "band": np.arange(1, data.shape[0] + 1),
        },
    )
    georeferenced.rio.write_crs(reference.rio.crs, inplace=True)
    georeferenced.rio.write_transform(reference.rio.transform(), inplace=True)
    return georeferenced


def _report_progress(progress_bar, percent, text):
    """Call ``progress_bar.progress(percent, text=text)`` when a progress bar is given."""
    if progress_bar:
        progress_bar.progress(percent, text=text)


def process_raster_and_overlays(
    raster_path: str,
    _model: torch.nn.Module,
    patch_size=512,
    stride=256,
    scaling_factor=None,
    rotate=False,
    batch_size=16,
    window_size=360,
    dilate_rows=False,
    _progress_bar=None,
):
    """Compute (or load) the mask, vNDVI and VDI layers for a raster and build folium overlays.

    Products are loaded from files next to ``raster_path`` when present
    (``<stem>_mask<ext>``, ``<stem>_vndvi_rows<ext>``,
    ``<stem>_vndvi_interrows<ext>``, ``<stem>_vdi<ext>``); otherwise they are
    computed with ``compute_mask`` / ``compute_vndvi`` / ``compute_vdi``. This
    function reads those files but never writes them.

    Args:
        raster_path: Input raster path (``str`` or ``os.PathLike``).
        _model: Model forwarded to ``compute_mask``. (The leading underscore
            presumably excludes it from Streamlit cache hashing — confirm at the caller.)
        patch_size, stride, scaling_factor, rotate, batch_size: Forwarded to ``compute_mask``.
        window_size: Forwarded to ``compute_vndvi`` and ``compute_vdi``.
        dilate_rows: Forwarded to ``compute_vndvi``.
        _progress_bar: Optional object with a ``progress(value, text=...)`` method,
            updated with values from 0 to 100 as work advances.

    Returns:
        list: five folium ImageOverlays — orthoimage (shown), mask, vNDVI rows,
        vNDVI interrows and VDI (the last four hidden initially).

    Raises:
        AssertionError: if the mask file exists but any other product file is missing.
    """
    mask_path = _cached_output_path(raster_path, "_mask")
    vndvi_rows_path = _cached_output_path(raster_path, "_vndvi_rows")
    vndvi_interrows_path = _cached_output_path(raster_path, "_vndvi_interrows")
    vdi_path = _cached_output_path(raster_path, "_vdi")

    # Checked once so the consistency check and the mask branch below agree.
    mask_cached = os.path.exists(mask_path)
    if mask_cached:
        # The in-memory ``mask`` array needed by compute_vndvi/compute_vdi exists
        # only when the mask is computed here, so a cached mask requires every
        # other product to be cached as well.
        for product_path in (vndvi_rows_path, vndvi_interrows_path, vdi_path):
            assert os.path.exists(product_path), (
                f"Found cached mask {mask_path!r} but {product_path!r} is missing."
            )
        logger.info(f"Found mask at {mask_path!r}, vNDVI at {vndvi_rows_path!r} and {vndvi_interrows_path!r}, and VDI at {vdi_path!r}. Loading...")

    logger.info(f'Reading raster image {raster_path!r}...')
    _report_progress(_progress_bar, 0, f'Reading raster image {raster_path!r}...')
    raster = rxr.open_rasterio(raster_path)

    # Mask
    logger.info('### Computing mask...')
    _report_progress(_progress_bar, 10, '### Computing mask...')
    if mask_cached:
        mask_raster = rxr.open_rasterio(mask_path)  # mask is RGBA (red for vine)
    else:
        mask = compute_mask(
            raster.to_numpy(),
            _model,
            patch_size=patch_size,
            stride=stride,
            scaling_factor=scaling_factor,
            rotate=rotate,
            batch_size=batch_size,
        )  # HxW uint8: 0 = background, 255 = vine, 1 = nodata
        # Render as RGBA: the red channel carries the mask, nodata pixels get alpha 0.
        alpha = ((mask != 1) * 255).astype(np.uint8)
        blank = np.zeros_like(mask)
        mask_colored = np.stack([mask, blank, blank, alpha], axis=0)  # 4xHxW uint8
        logger.info('Georeferencing mask...')
        _report_progress(_progress_bar, 30, 'Georeferencing mask...')
        mask_raster = _georeference_like(mask_colored, raster)

    # vNDVI (rows and interrows are cached / recomputed together)
    logger.info('### Computing vNDVI...')
    _report_progress(_progress_bar, 35, '### Computing vNDVI...')
    if os.path.exists(vndvi_rows_path) and os.path.exists(vndvi_interrows_path):
        vndvi_rows_raster = rxr.open_rasterio(vndvi_rows_path)  # vNDVI is RGBA
        vndvi_interrows_raster = rxr.open_rasterio(vndvi_interrows_path)  # vNDVI is RGBA
    else:
        vndvi_rows, vndvi_interrows = compute_vndvi(
            raster.to_numpy(),
            mask,
            dilate_rows=dilate_rows,
            window_size=window_size,
        )  # HxWx4 RGBA arrays
        logger.info('Georeferencing vNDVI...')
        _report_progress(_progress_bar, 55, 'Georeferencing vNDVI...')
        vndvi_rows_raster = _georeference_like(vndvi_rows.transpose(2, 0, 1), raster)
        vndvi_interrows_raster = _georeference_like(vndvi_interrows.transpose(2, 0, 1), raster)

    # VDI
    logger.info('### Computing VDI...')
    _report_progress(_progress_bar, 60, '### Computing VDI...')
    if os.path.exists(vdi_path):
        vdi_raster = rxr.open_rasterio(vdi_path)  # VDI is RGBA
    else:
        vdi = compute_vdi(
            raster.to_numpy(),
            mask,
            window_size=window_size,
        )  # HxWx4 RGBA array
        logger.info('Georeferencing VDI...')
        _report_progress(_progress_bar, 80, 'Georeferencing VDI...')
        vdi_raster = _georeference_like(vdi.transpose(2, 0, 1), raster)

    # Reproject everything to EPSG:4326. Pixels outside each source footprint are
    # filled with 0 in every band; for the RGBA layers that includes alpha, so
    # those areas should render transparent. Reprojected one at a time so each
    # source array can be released before the next one is processed.
    if raster.rio.crs.to_epsg() != 4326:
        logger.info("Reprojecting rasters to EPSG:4326 with NODATA value 0...")
        _report_progress(_progress_bar, 82, "Reprojecting rasters to EPSG:4326 with NODATA value 0...")
        raster = raster.rio.reproject("EPSG:4326", nodata=0)
        mask_raster = mask_raster.rio.reproject("EPSG:4326", nodata=0)
        vndvi_rows_raster = vndvi_rows_raster.rio.reproject("EPSG:4326", nodata=0)
        vndvi_interrows_raster = vndvi_interrows_raster.rio.reproject("EPSG:4326", nodata=0)
        vdi_raster = vdi_raster.rio.reproject("EPSG:4326", nodata=0)

    # (log message, progress %, progress text, layer, overlay name, initially shown)
    overlay_specs = (
        ('Creating RGB raster overlay...', 85, 'Creating overlays: drone image...', raster, "Orthoimage", True),
        ('Creating mask overlay...', 88, 'Creating overlays: mask...', mask_raster, "Mask", False),
        ('Creating vNDVI rows overlay...', 91, 'Creating overlays: vNDVI (rows)...', vndvi_rows_raster, "vNDVI Rows", False),
        ('Creating vNDVI interrows overlay...', 94, 'Creating overlays: vNDVI (interrows)...', vndvi_interrows_raster, "vNDVI Interrows", False),
        ('Creating VDI overlay...', 97, 'Creating overlays: VDI...', vdi_raster, "VDI", False),
    )
    overlays = []
    for log_text, percent, progress_text, layer, layer_name, layer_shown in overlay_specs:
        logger.info(log_text)
        _report_progress(_progress_bar, percent, progress_text)
        overlays.append(create_image_overlay(layer, name=layer_name, opacity=1.0, show=layer_shown))

    logger.info('Done!')
    _report_progress(_progress_bar, 100, 'Done!')
    return overlays