# gaia-growseg-demo / lib/folium.py
# Author: tommonopoli
# Last commit: 46818a5 ("update requirements.txt")
#import branca
import folium
import streamlit as st
from loguru import logger
import rioxarray as rxr
import numpy as np
import xarray as xr
import torch
from .utils import compute_mask, compute_vndvi, compute_vdi
import os
@st.cache_resource
def create_map(location=(41.9099533, 12.3711879), zoom_start=5, crs=3857, max_zoom=23):
    """Create a folium map with OpenStreetMap tiles and an optional Esri.WorldImagery basemap.

    Args:
        location: (latitude, longitude) of the initial map centre. A tuple so the
            default value is immutable.
        zoom_start: Initial zoom level.
        crs: CRS as an EPSG int (``3857``) or string (``"EPSG3857"`` or
            ``"EPSG:3857"``). Only Web Mercator (EPSG:3857) is supported.
        max_zoom: Maximum zoom level of the map.

    Returns:
        A ``folium.Map`` whose default base layer is OpenStreetMap, with
        Esri.WorldImagery registered as an alternative base layer that is
        hidden initially (selectable through a layer control).

    Raises:
        ValueError: If ``crs`` is not EPSG:3857.
    """
    if isinstance(crs, int):
        crs = f"EPSG{crs}"
    elif isinstance(crs, str):
        # Also accept the "EPSG:3857" spelling; folium's CRS names have no colon.
        crs = crs.replace(":", "")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if crs != "EPSG3857":
        raise ValueError(f"Only EPSG:3857 supported for now. Got {crs}.")

    # NOTE(review): st.cache_resource returns this same Map object to every
    # caller and session; adding layers to it mutates the shared instance.
    # Confirm callers do not accumulate per-session layers on it.
    m = folium.Map(
        location=location,
        zoom_start=zoom_start,
        crs=crs,
        max_zoom=max_zoom,
        tiles="OpenStreetMap",  # Esri.WorldImagery
        attributionControl=False,
        prefer_canvas=True,
    )
    # Add Esri.WorldImagery as an optional base layer (radio button in the layer control).
    folium.TileLayer(
        tiles="Esri.WorldImagery",
        show=False,
        overlay=False,
        control=True,
    ).add_to(m)
    return m
def get_clean_rendering_container(app_state: str):
    """Return a Streamlit container whose slot is swapped whenever ``app_state`` changes.

    Two placeholders ("a" and "b") are created on every call and only one is
    handed back. The active slot name is kept in
    ``st.session_state.slot_in_use`` (default "a") and flips whenever
    ``app_state`` differs from ``st.session_state.previous_state``. After a
    state change, the slot used for the previous state is therefore left empty
    on this run, giving the new state a clean slate to render into.

    Args:
        app_state: Label of the current app state; a change triggers a slot swap.

    Returns:
        The ``container()`` of the active placeholder.
    """
    session = st.session_state
    active = session.get("slot_in_use", "a")
    session.slot_in_use = active

    last_state = session.get("previous_state", app_state)
    if app_state != last_state:
        active = "b" if active == "a" else "a"
        session.slot_in_use = active
    session.previous_state = app_state

    # Both placeholders are created, always in this order, on every run;
    # only the active one is returned to the caller.
    slot_a = st.empty()
    slot_b = st.empty()
    chosen = {"a": slot_a, "b": slot_b}[active]
    return chosen.container()
def create_image_overlay(raster_path_or_array, name="Raster", opacity=1.0, to_crs=4326, show=True):
    """Create a folium ImageOverlay from a raster filepath or an xarray.DataArray.

    Args:
        raster_path_or_array: Path to a raster readable by ``rioxarray.open_rasterio``,
            or an already-opened DataArray with ``band``, ``y`` and ``x`` dims and a CRS.
        name: Name of the overlay layer.
        opacity: Overlay opacity.
        to_crs: EPSG code (int) to reproject to, unless the raster is already in it.
        show: Whether the layer is visible when the map first loads.

    Returns:
        A ``folium.raster_layers.ImageOverlay`` placed at the raster's bounds.
    """
    raster = (
        rxr.open_rasterio(raster_path_or_array)
        if isinstance(raster_path_or_array, str)
        else raster_path_or_array
    )

    fill_value = raster.rio.nodata or 0
    if raster.rio.crs.to_epsg() != to_crs:
        # Pass nodata explicitly so rioxarray does not fall back to its own default fill (255 here).
        raster = raster.rio.reproject(to_crs, nodata=fill_value)

    # folium expects an image laid out as (rows, columns, channels).
    raster = raster.transpose("y", "x", "band")
    left, bottom, right, top = raster.rio.bounds()

    return folium.raster_layers.ImageOverlay(
        image=raster.to_numpy(),
        name=name,
        # folium bounds are ((south, west), (north, east)).
        bounds=[[bottom, left], [top, right]],
        opacity=opacity,
        interactive=True,
        cross_origin=False,
        zindex=1,
        show=show,
    )
def _georef_like(array, reference):
    """Wrap a ``(band, y, x)`` array in a DataArray on *reference*'s pixel grid.

    The result gets *reference*'s x/y coordinates, CRS and affine transform.
    Band coordinates are numbered ``1..N`` from ``array`` itself (the numbering
    ``rioxarray.open_rasterio`` uses). ``array`` may therefore have a different
    band count than *reference*, e.g. a 4-band RGBA overlay derived from a
    3-band RGB orthoimage.
    """
    georef = xr.DataArray(
        array,
        dims=("band", "y", "x"),
        coords={
            "x": reference.x,
            "y": reference.y,
            "band": np.arange(1, array.shape[0] + 1),
        },
    )
    georef.rio.write_crs(reference.rio.crs, inplace=True)  # copy CRS
    georef.rio.write_transform(reference.rio.transform(), inplace=True)  # copy affine transform
    return georef


def _sidecar_path(raster_path, suffix):
    """Return ``<raster path without extension><suffix><raster extension>``."""
    root, ext = os.path.splitext(raster_path)
    return f"{root}{suffix}{ext}"


@st.cache_resource
def process_raster_and_overlays(
    raster_path: str,
    _model: torch.nn.Module,
    patch_size=512,
    stride=256,
    scaling_factor=None,
    rotate=False,
    batch_size=16,
    window_size=360,
    dilate_rows=False,
    _progress_bar=None,
):
    """Segment a raster and build the folium overlays for all derived products.

    The vine mask, vNDVI (rows / inter-rows) and VDI are either loaded from
    sidecar files next to ``raster_path`` or computed. The sidecar files are
    ``<stem>_mask``, ``<stem>_vndvi_rows``, ``<stem>_vndvi_interrows`` and
    ``<stem>_vdi``, each with the raster's extension. This function only reads
    those files and never writes them. Everything is then reprojected to
    EPSG:4326 and wrapped in folium image overlays.

    Args:
        raster_path: Path to the input orthoimage.
        _model: Segmentation model forwarded to ``compute_mask``. The leading
            underscore makes Streamlit's cache skip hashing it.
        patch_size, stride, scaling_factor, rotate, batch_size: Forwarded to
            ``compute_mask``.
        window_size: Forwarded to ``compute_vndvi`` and ``compute_vdi``.
        dilate_rows: Forwarded to ``compute_vndvi``.
        _progress_bar: Optional Streamlit progress element. It is updated with
            0-100 values and a status text and is not part of the cache key.

    Returns:
        Overlays ``[orthoimage, mask, vNDVI rows, vNDVI interrows, VDI]``.
        Only the orthoimage overlay is shown by default.
    """
    def report(percent, text):
        if _progress_bar:
            _progress_bar.progress(percent, text=text)

    # Sidecar paths built with splitext: str.replace('.tif', ...) would return
    # raster_path itself for e.g. '.TIF' files, loading the image as its own mask.
    mask_path = _sidecar_path(raster_path, '_mask')
    vndvi_rows_path = _sidecar_path(raster_path, '_vndvi_rows')
    vndvi_interrows_path = _sidecar_path(raster_path, '_vndvi_interrows')
    vdi_path = _sidecar_path(raster_path, '_vdi')

    has_mask = os.path.exists(mask_path)
    has_vndvi = os.path.exists(vndvi_rows_path) and os.path.exists(vndvi_interrows_path)
    has_vdi = os.path.exists(vdi_path)
    # The numeric mask is an input to vNDVI and VDI. The cached (RGBA) mask can
    # therefore only be reused when neither of those needs recomputing.
    reuse_mask = has_mask and has_vndvi and has_vdi
    if reuse_mask:
        logger.info(f"Found mask at {mask_path!r}, vNDVI at {vndvi_rows_path!r} and {vndvi_interrows_path!r}, and VDI at {vdi_path!r}. Loading...")
    elif has_mask:
        logger.warning(
            f"Found mask at {mask_path!r} but vNDVI/VDI outputs are incomplete; "
            "recomputing the mask so the missing products can be derived from it."
        )

    # Read raster
    logger.info(f'Reading raster image {raster_path!r}...')
    report(0, f'Reading raster image {raster_path!r}...')
    raster = rxr.open_rasterio(raster_path)

    # Compute mask
    logger.info('### Computing mask...')
    report(10, '### Computing mask...')
    mask = None
    if reuse_mask:
        mask_raster = rxr.open_rasterio(mask_path)  # mask is RGBA (red for vine)
    else:
        mask = compute_mask(
            raster.to_numpy(),
            _model,
            patch_size=patch_size,
            stride=stride,
            scaling_factor=scaling_factor,
            rotate=rotate,
            batch_size=batch_size,
        )  # mask is a HxW uint8 array with 0=background, 255=vine, 1=nodata
        # Grayscale -> RGBA: the red channel carries the mask and nodata pixels are transparent.
        alpha = ((mask != 1) * 255).astype(np.uint8)
        zeros = np.zeros_like(mask)
        mask_colored = np.stack([mask, zeros, zeros, alpha], axis=0)  # 4xHxW uint8
        logger.info('Georeferencing mask...')
        report(30, 'Georeferencing mask...')
        mask_raster = _georef_like(mask_colored, raster)

    # Compute vNDVI
    logger.info('### Computing vNDVI...')
    report(35, '### Computing vNDVI...')
    if has_vndvi:
        vndvi_rows_raster = rxr.open_rasterio(vndvi_rows_path)  # vNDVI is RGBA
        vndvi_interrows_raster = rxr.open_rasterio(vndvi_interrows_path)  # vNDVI is RGBA
    else:
        vndvi_rows, vndvi_interrows = compute_vndvi(
            raster.to_numpy(),
            mask,
            dilate_rows=dilate_rows,
            window_size=window_size,
        )  # RGBA images; the transpose below assumes (H, W, C) layout
        logger.info('Georeferencing vNDVI...')
        report(55, 'Georeferencing vNDVI...')
        vndvi_rows_raster = _georef_like(vndvi_rows.transpose(2, 0, 1), raster)
        vndvi_interrows_raster = _georef_like(vndvi_interrows.transpose(2, 0, 1), raster)

    # Compute VDI
    logger.info('### Computing VDI...')
    report(60, '### Computing VDI...')
    if has_vdi:
        vdi_raster = rxr.open_rasterio(vdi_path)  # VDI is RGBA
    else:
        vdi = compute_vdi(
            raster.to_numpy(),
            mask,
            window_size=window_size,
        )  # RGBA image; the transpose below assumes (H, W, C) layout
        logger.info('Georeferencing VDI...')
        report(80, 'Georeferencing VDI...')
        vdi_raster = _georef_like(vdi.transpose(2, 0, 1), raster)

    # Reproject all rasters to EPSG:4326
    if raster.rio.crs.to_epsg() != 4326:
        logger.info("Reprojecting rasters to EPSG:4326 with NODATA value 0...")
        report(82, "Reprojecting rasters to EPSG:4326 with NODATA value 0...")
        # Explicit nodata=0 rather than rioxarray's dtype-dependent default fill.
        raster, mask_raster, vndvi_rows_raster, vndvi_interrows_raster, vdi_raster = (
            layer.rio.reproject("EPSG:4326", nodata=0)
            for layer in (raster, mask_raster, vndvi_rows_raster, vndvi_interrows_raster, vdi_raster)
        )

    # Create overlays: (data, layer name, visible, progress %, log text, progress text)
    layers = [
        (raster, "Orthoimage", True, 85,
         'Creating RGB raster overlay...', 'Creating overlays: drone image...'),
        (mask_raster, "Mask", False, 88,
         'Creating mask overlay...', 'Creating overlays: mask...'),
        (vndvi_rows_raster, "vNDVI Rows", False, 91,
         'Creating vNDVI rows overlay...', 'Creating overlays: vNDVI (rows)...'),
        (vndvi_interrows_raster, "vNDVI Interrows", False, 94,
         'Creating vNDVI interrows overlay...', 'Creating overlays: vNDVI (interrows)...'),
        (vdi_raster, "VDI", False, 97,
         'Creating VDI overlay...', 'Creating overlays: VDI...'),
    ]
    overlays = []
    for data, name, show, percent, log_text, progress_text in layers:
        logger.info(log_text)
        report(percent, progress_text)
        overlays.append(create_image_overlay(data, name=name, opacity=1.0, show=show))

    logger.info('Done!')
    report(100, 'Done!')
    return overlays