Spaces:
Starting
Starting
import numpy as np | |
import torch | |
import rasterio | |
import xarray as xr | |
import rioxarray as rxr | |
import cv2 | |
from transformers import SegformerForSemanticSegmentation | |
from tqdm import tqdm | |
from scipy.ndimage import grey_dilation | |
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
from .viz_utils import alpha_composite | |
from loguru import logger | |
def resize(img, shape=None, scaling_factor=1., order='CHW'):
    """Resize an image either to an explicit shape or by a scaling factor.

    Args:
        img (np.ndarray): image laid out as CHW or HWC, according to `order`.
        shape (tuple, optional): target spatial size as (H, W). It cannot be
            combined with a `scaling_factor` different from 1.
        scaling_factor (float): factor applied to both spatial axes when
            `shape` is not given.
        order (str): layout of `img`, either 'CHW' or 'HWC'.

    Returns:
        np.ndarray: the resized image, in the same layout as the input and
            always with an explicit channel axis.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    assert shape is None or scaling_factor == 1., "Got both shape and scaling_factor. Please provide only one of them"

    channels_first = order == 'CHW'
    # cv2 works on channel-last images
    hwc = np.moveaxis(img, 0, -1) if channels_first else img

    if shape is None:
        resized = cv2.resize(hwc, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_LINEAR)
    else:
        # cv2 takes the target size as (W, H), hence the reversal
        resized = cv2.resize(hwc, shape[::-1], interpolation=cv2.INTER_LINEAR)

    # NB: cv2.resize returns a HW image when given a HW1 one: restore the C axis
    if resized.ndim == 2:
        resized = resized[..., None]

    return np.moveaxis(resized, -1, 0) if channels_first else resized
def minimum_needed_padding(img_size, patch_size: int, stride: int):
    """
    Compute the smallest padding that lets patches of side `patch_size`, taken
    every `stride` pixels, cover the image exactly.

    Args:
        img_size (tuple): the spatial size (H, W) of the image
        patch_size (int): the side of the patches to extract
        stride (int): the step between consecutive patches
    Returns:
        tuple: the padding as (pad_t, pad_b, pad_l, pad_r). When the total
            padding along an axis is odd, the extra pixel goes to the
            bottom/right side.
    """
    size = np.array(img_size)
    # An axis not longer than a patch is grown to exactly one patch; the modulo
    # leaves an empty axis (size 0) and an already patch-sized axis untouched.
    up_to_one_patch = (patch_size - size) % patch_size
    # A longer axis is grown to the next length of the form patch_size + k * stride.
    up_to_next_step = (stride - (size - patch_size)) % stride
    total = np.where(size <= patch_size, up_to_one_patch, up_to_next_step)

    pad_t, pad_l = total // 2
    total_h, total_w = total
    return pad_t, total_h - pad_t, pad_l, total_w - pad_l
def pad(img, pad, order='CHW'):
    """Zero-pad the spatial axes of an image.

    Args:
        img: image laid out as CHW or HWC (numpy array or torch tensor).
        pad (tuple): padding amounts as (pad_t, pad_b, pad_l, pad_r).
        order (str): layout of `img`, either 'CHW' or 'HWC'.

    Returns:
        The padded image: a torch.Tensor if the input was one, a numpy array
        otherwise.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    top, bottom, left, right = pad

    spatial_widths = ((top, bottom), (left, right))
    channel_widths = ((0, 0),)  # the channel axis is never padded
    if order == 'HWC':
        widths = spatial_widths + channel_widths
    else:
        widths = channel_widths + spatial_widths

    result = np.pad(img, widths, mode='constant', constant_values=0)  # can also try mode='reflect'

    # np.pad hands back a numpy array: restore the tensor type if needed
    return torch.tensor(result) if isinstance(img, torch.Tensor) else result
def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True, include_last=True):
    """Enumerate (and optionally cut out) square sliding-window patches.

    Args:
        img: 3-dimensional image laid out as CHW or HWC, according to `order`.
        patch_size (int): side of the square patches.
        stride (int): step between consecutive patches along each axis.
        order (str): layout of `img`, either 'CHW' or 'HWC'.
        only_return_idx (bool): if True, return only the patch indices.
        include_last (bool): if True and the regular grid does not reach the
            bottom/right border, add extra patches aligned to that border.

    Returns:
        list: the patch indices, each as (top, bottom, left, right), when
            `only_return_idx` is True;
        tuple: (patches, patches_idx) otherwise, where `patches` holds the
            slices of `img` matching each index.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    assert len(img.shape) == 3, f"Got image with {len(img.shape)} dimensions, expected 3 dimensions (C,H,W) or (H,W,C)"
    # Get image height and width according to the layout
    if order == 'HWC':
        H, W = img.shape[:2]
    else:
        H, W = img.shape[1:]
    # The size checks must look at the spatial axes of the chosen layout:
    # reading shape[0]/shape[1] directly would test C and H for CHW images.
    assert H >= patch_size, f"Got image with height {H}, expected at least {patch_size}. Maybe apply padding first?"
    assert W >= patch_size, f"Got image with width {W}, expected at least {patch_size}. Maybe apply padding first?"
    # Compute the number of "proper" patches in each dimension
    n_patches_H = (H - patch_size) // stride + 1
    n_patches_W = (W - patch_size) // stride + 1
    # Extract patches indices
    patches_idx = []
    for i in range(n_patches_H):  # iterate over height
        top = i * stride
        bottom = top + patch_size
        last_row = (i == n_patches_H - 1)
        for j in range(n_patches_W):  # iterate over width
            left = j * stride
            right = left + patch_size
            last_col = (j == n_patches_W - 1)
            patches_idx.append((top, bottom, left, right))
            # Include rightmost and lowermost patches if the grid leaves a remainder
            if include_last:
                if last_col and right < W:
                    patches_idx.append((top, bottom, W - patch_size, W))
                if last_row and bottom < H:
                    patches_idx.append((H - patch_size, H, left, right))
                if last_row and last_col and bottom < H and right < W:
                    patches_idx.append((H - patch_size, H, W - patch_size, W))
    if only_return_idx:
        return patches_idx
    # Extract patches
    if order == 'HWC':
        patches = [img[t:b, l:r, :] for t, b, l, r in patches_idx]
    else:
        patches = [img[:, t:b, l:r] for t, b, l, r in patches_idx]
    return patches, patches_idx
def segment_batch(batch, model):
    """Run `model` on a batch of patches and return confidence scores.

    The forward pass runs with gradients disabled. For a
    SegformerForSemanticSegmentation model the logits are read from the
    `.logits` attribute of the output and upsampled to the spatial size of
    `batch`; for any other model the output itself is used as logits. A
    sigmoid then turns the logits into scores.
    """
    is_segformer = isinstance(model, SegformerForSemanticSegmentation)
    with torch.no_grad():
        logits = model(batch)
        if is_segformer:
            logits = upsample(logits.logits, size=batch.shape[-2:])
        scores = torch.sigmoid(logits)  # logits -> confidence scores
    return scores
def upsample(x, size):
    """Upsample a 3D/4D/5D tensor to the given spatial size.

    The interpolation mode follows the rank of `x`: linear for 3D (N, C, L),
    bilinear for 4D (N, C, H, W) and trilinear for 5D (N, C, D, H, W) inputs,
    always with `align_corners=False`.

    Args:
        x (torch.Tensor): tensor with 3, 4 or 5 dimensions.
        size: target spatial size, as accepted by
            `torch.nn.functional.interpolate`.

    Returns:
        torch.Tensor: the resized tensor.

    Raises:
        ValueError: if `x` does not have 3, 4 or 5 dimensions.
    """
    mode_by_rank = {3: 'linear', 4: 'bilinear', 5: 'trilinear'}
    mode = mode_by_rank.get(x.dim())
    if mode is None:
        raise ValueError(f"Got a tensor with {x.dim()} dimensions, expected 3, 4 or 5")
    return torch.nn.functional.interpolate(x, size=size, mode=mode, align_corners=False)
def merge_patches(patches, patches_idx, rotate=False, canvas_shape=None, order='CHW'):
    """Merge (possibly overlapping) patches into a single image by averaging.

    Args:
        patches: non-empty sequence of patches, each laid out as `order`.
        patches_idx: one (top, bottom, left, right) tuple per patch, giving
            its position on the canvas.
        rotate (bool): if True, patch number i is rotated by -i quarter turns
            before merging, so patches are expected to arrive with a rotation
            of i quarter turns (modulo 4).
        canvas_shape (tuple, optional): (H, W) of the output. If None, it is
            inferred from the largest bottom/right index in `patches_idx`.
        order (str): layout of the patches, either 'CHW' or 'HWC'.

    Returns:
        np.ndarray: the merged image in the same layout. Each pixel is the
            mean of the patches covering it; pixels covered by no patch are 0.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    # A length mismatch would be silently truncated by zip() below
    assert len(patches) == len(patches_idx), f"Got {len(patches)} patches and {len(patches_idx)} indexes"
    if rotate:
        axes_to_rotate = (0, 1) if order == 'HWC' else (1, 2)
        patches = [np.rot90(p, -i, axes=axes_to_rotate) for i, p in enumerate(patches)]
    # if canvas_shape is None, infer it from patches_idx
    if canvas_shape is None:
        canvas_H = max(idx[1] for idx in patches_idx)
        canvas_W = max(idx[3] for idx in patches_idx)
    else:
        canvas_H, canvas_W = canvas_shape
    # initialize canvas
    dtype = patches[0].dtype
    if order == 'HWC':
        canvas_C = patches[0].shape[-1]
        canvas = np.zeros((canvas_H, canvas_W, canvas_C), dtype=dtype)  # HWC
        n_overlapping_patches = np.zeros((canvas_H, canvas_W, 1))
    else:
        canvas_C = patches[0].shape[0]
        canvas = np.zeros((canvas_C, canvas_H, canvas_W), dtype=dtype)  # CHW
        n_overlapping_patches = np.zeros((1, canvas_H, canvas_W))
    # merge patches
    for p, (t, b, l, r) in zip(patches, patches_idx):
        if order == 'HWC':
            canvas[t:b, l:r, :] += p
            n_overlapping_patches[t:b, l:r, 0] += 1
        else:
            canvas[:, t:b, l:r] += p
            n_overlapping_patches[0, t:b, l:r] += 1
    # compute average. With `where=`, np.divide does not write the positions
    # where the condition is False, so the output must be pre-filled: without
    # an explicit `out`, uncovered pixels would hold uninitialized memory.
    averaged = np.zeros(canvas.shape, dtype=np.result_type(canvas, n_overlapping_patches))
    np.divide(canvas, n_overlapping_patches, out=averaged, where=(n_overlapping_patches != 0))
    return averaged
def segment(img, model, patch_size=512, stride=256, scaling_factor=1., rotate=False, device=None, batch_size=16, verbose=False):
    """Segment an RGB(A) image with a segmentation model, patch by patch.

    The image is resized by `scaling_factor`, zero-padded so that patches tile
    it exactly, split into overlapping patches, and the per-patch predictions
    are averaged back into a single map at the original resolution.

    Args:
        img (np.ndarray): uint8 image laid out as CHW, with C = 3 (RGB) or
            C = 4 (RGBA).
        model: segmentation model, called through `segment_batch`.
        patch_size (int): side of the square patches fed to the model.
        stride (int): step between consecutive patches.
        scaling_factor (float): resize factor applied before segmentation.
        rotate (bool): if True, each patch is also predicted after 1, 2 and 3
            quarter-turn rotations and the four predictions are averaged.
        device: device on which the model and the batches are moved.
        batch_size (int): number of patches that triggers a prediction.
        verbose (bool): if True, show a progress bar.

    Returns:
        np.ndarray: confidence scores at the original resolution, with -1
            wherever the alpha channel of the input is 0. Patches that are
            entirely 0 or entirely 255 are not predicted and contribute
            nothing to the average.
    """
    # some checks
    assert isinstance(img, np.ndarray), f"Input must be a numpy array. Got {type(img)}"
    assert img.shape[0] in [3,4], f"Input image must be formatted as CHW, with C = 3,4. Got a shape of {img.shape}"
    assert img.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {img.dtype}"
    # prepare model for evaluation
    model = model.to(device)
    model.eval()
    # prepare alpha channel
    original_shape = img.shape
    if img.shape[0] == 3:
        # create dummy alpha channel
        alpha = np.full(original_shape[1:], 255, dtype=np.uint8)
    else:
        # extract alpha channel
        img, alpha = img[:3], img[3]
    # resize image
    img = resize(img, scaling_factor=scaling_factor)
    # pad image
    pad_t, pad_b, pad_l, pad_r = minimum_needed_padding(img.shape[1:], patch_size, stride)
    padded_img = pad(img, pad=(pad_t, pad_b, pad_l, pad_r))
    padded_shape = padded_img.shape
    # extract patches indexes
    patches_idx = extract_patches(padded_img, patch_size=patch_size, stride=stride)
    # ImageNet mean and std, built once instead of once per patch
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    ### segment
    masks = []
    masks_idx = []
    batch = []
    last_i = len(patches_idx) - 1
    for i, p_idx in enumerate(tqdm(patches_idx, disable=not verbose, desc="Predicting...", total=len(patches_idx))):
        t, b, l, r = p_idx
        # extract patch
        patch = padded_img[:, t:b, l:r]
        # consider patch only if it is valid (i.e. not all black or all white)
        if np.any(patch != 0) and np.any(patch != 255):
            # convert patch to torch.tensor with float32 values in [0,1] (as required by torch)
            patch = torch.tensor(patch).float() / 255.
            # normalize patch with ImageNet mean and std
            patch = (patch - mean) / std
            # add patch to batch
            batch.append(patch)
            masks_idx.append(p_idx)
            # (optional) for each patch extracted, consider also its rotated versions
            if rotate:
                for rot in range(1, 4):
                    # Always rotate the un-rotated patch: merge_patches undoes
                    # 0, 1, 2, 3 quarter turns in this order, so rotations must
                    # not accumulate from one iteration to the next.
                    batch.append(torch.rot90(patch, rot, dims=[1, 2]))
                    masks_idx.append(p_idx)
        # Predict when the batch is full or the patches are over. The check is
        # done for every patch (valid or not) so that pending patches are
        # flushed even when the last one is skipped, and only when there is
        # something to stack.
        if batch and (len(batch) >= batch_size or i == last_i):
            # move batch to the target device and perform prediction
            out = segment_batch(torch.stack(batch).to(device), model)
            # append predictions to masks
            masks.append(out.cpu().numpy())
            # reset batch
            batch = []
    if masks:
        # concatenate predictions
        masks = np.concatenate(masks)  # (n_patches, 1, H, W)
        # merge patches
        mask = merge_patches(masks, masks_idx, rotate=rotate, canvas_shape=padded_shape[1:])  # (1, H, W)
    else:
        # no valid patch (e.g. a completely blank image): nothing was predicted
        mask = np.zeros((1,) + tuple(padded_shape[1:]), dtype=np.float64)
    # undo padding
    mask = mask[:, pad_t:padded_shape[1]-pad_b, pad_l:padded_shape[2]-pad_r]
    # resize mask to original shape
    mask = resize(mask, shape=original_shape[1:])
    # apply alpha channel, i.e. set to -1 the pixels where alpha is 0
    mask = np.where(alpha == 0, -1, mask)
    return mask.squeeze()
def sliding_window_avg_pooling(img, window, granularity, alpha=None, min_nonblank_pixels=0., order="HWC", normalize=False, return_min_max=False, verbose=False):
    """Smooth a single-channel image with an alpha-weighted sliding-window mean.

    For every window of side `window`, taken every `granularity` pixels, the
    alpha-weighted mean of `img` inside the window is computed and assigned to
    all the pixels of the window. Each output pixel is the average of the
    values of the windows covering it, multiplied by its alpha.

    Args:
        img (np.ndarray): single-channel image, HW1 or 1HW according to `order`.
        window (int): side of the sliding window.
        granularity (int): step between consecutive windows.
        alpha (np.ndarray, optional): per-pixel weights with the same shape as
            `img`. uint8 values are divided by 255 and bool values are cast to
            float32; other dtypes are used as they are. Defaults to all ones.
        min_nonblank_pixels (float): fraction of the window area; windows
            whose alpha sum does not exceed it are skipped.
        order (str): layout of `img` and `alpha`, 'HWC' or 'CHW'.
        normalize (bool): if True, rescale the result to [0, 1] with its own
            minimum and maximum (a constant map becomes all zeros).
        return_min_max (bool): if True, also return the minimum and maximum of
            the map, measured before normalization.
        verbose (bool): if True, show a progress bar.

    Returns:
        np.ndarray, or (np.ndarray, min, max) if `return_min_max` is True.
    """
    assert isinstance(img, np.ndarray), f'Input image must be a numpy array. Got {type(img)}'
    if order == "HWC":
        assert img.shape[2] == 1, f'Input image must be formatted as HWC, with C = 1. Got a shape of {img.shape}'
    elif order == "CHW":
        assert img.shape[0] == 1, f'Input image must be formatted as CHW, with C = 1. Got a shape of {img.shape}'
    # check if alpha channel was given, and cast it to np.float32 with values in [0,1]
    if alpha is not None:
        assert img.shape == alpha.shape, f'The shape of input image {img.shape} and alpha channel {alpha.shape} do not match'
        if alpha.dtype == np.uint8:
            alpha = (alpha / 255).astype(np.float32)
        elif alpha.dtype == bool:
            alpha = alpha.astype(np.float32)
    else:
        alpha = np.ones_like(img, dtype=np.float32)
    # compute threshold
    thresh = min_nonblank_pixels * window**2
    # extract patches idxs
    patches_idx = extract_patches(img, patch_size=window, stride=granularity, order=order, only_return_idx=True)
    # initialize canvas
    canvas = np.zeros_like(img, dtype=np.float32)
    n_overlapping_patches = np.zeros_like(img, dtype=np.float32)
    # cycle through patches idxs
    for t, b, l, r in tqdm(patches_idx, disable=not verbose):
        # the window must be cut along the spatial axes of the chosen layout
        if order == "HWC":
            win = (slice(t, b), slice(l, r))
        else:
            win = (slice(None), slice(t, b), slice(l, r))
        p_a = alpha[win]
        n_valid_pixels = p_a.sum()
        # keep only if it has more than min_nonblank_pixels
        if n_valid_pixels <= thresh:
            continue
        # compute average patch value (i.e. density inside the patch)
        p_density = (img[win] * p_a).sum() / n_valid_pixels
        # add to canvas
        canvas[win] += p_density
        n_overlapping_patches[win] += 1
    # compute average density. With `where=`, np.divide leaves untouched the
    # positions where the condition is False, so start from zeros instead of
    # an uninitialized output.
    density_map = np.zeros_like(canvas)
    np.divide(canvas, n_overlapping_patches, out=density_map, where=(n_overlapping_patches != 0))
    # apply alpha
    density_map = density_map * alpha
    # the extrema are needed by both options, so compute them if either is on
    density_map_min = density_map_max = None
    if normalize or return_min_max:
        density_map_min = density_map.min()
        density_map_max = density_map.max()
    if normalize:
        # [0,1]-normalize; a constant map has a zero range and is left at 0
        value_range = density_map_max - density_map_min
        density_map = density_map - density_map_min
        if value_range != 0:
            density_map = density_map / value_range
    if return_min_max:
        return density_map, density_map_min, density_map_max
    return density_map
def compute_vndvi(
    raster: np.ndarray,
    mask: np.ndarray,
    dilate_rows=True,
    window_size=360,
    granularity=45,
):
    """Compute vNDVI heatmaps for the rows and the interrows of a vineyard.

    Args:
        raster (np.ndarray): CHW image with 3 (RGB) or 4 (RGBA) channels. The
            channels are divided by 255 before computing the index.
        mask (np.ndarray): HW mask with values 0=interrows, 255=rows,
            1=nodata.
        dilate_rows (bool): if True, the rows heatmap and the rows mask are
            dilated before compositing.
        window_size (int): base size from which the pooling window
            (`window_size / 4`) and the dilation size (`window_size / 60`)
            are derived.
        granularity (int): step of the pooling window.

    Returns:
        tuple: (vndvi_rows_img, vndvi_interrows_img), as produced by
            `alpha_composite` (HW4 RGBA according to the original notes).
            Both are colored with the 'RdYlGn' colormap, whose ends
            correspond to the 10th and 90th percentile of the vNDVI over the
            valid pixels.
    """
    assert isinstance(raster, np.ndarray)
    assert isinstance(mask, np.ndarray)
    assert len(raster.shape) == 3 # CHW
    assert len(mask.shape) == 2 # HW
    assert raster.shape[0] in [3,4] # RGB or RGBA
    # CHW -> HWC
    raster = raster.transpose(1,2,0)
    # Extract channels
    _raster = raster.astype(np.float32) / 255 # convert to float32 in [0,1]
    R, G, B = _raster[:,:,0], _raster[:,:,1], _raster[:,:,2]
    # To avoid division by 0 due to negative power, we replace 0 with 1 in R and B channels
    R = np.where(R == 0, 1, R)
    B = np.where(B == 0, 1, B)
    # Mask has values: 0=interrows, 255=rows, 1=nodata
    # Get mask for the rows and interrows
    mask_rows = (mask == 255)
    mask_interrows = (mask == 0)
    mask_valid = mask_rows | mask_interrows
    # Compute vndvi
    vndvi = 0.5268 * (R**(-0.1294) * G**(0.3389) * B**(-0.3118))
    # Clip values to [0,1]
    vndvi = np.clip(vndvi, 0, 1)
    # Compute 10th and 90th percentile on whole vineyard vndvi heatmap
    vndvi_perc10, vndvi_perc90 = np.percentile(vndvi[mask_valid], [10,90])
    # Clip values between 10th and 90th percentile
    vndvi_clipped = np.clip(vndvi, vndvi_perc10, vndvi_perc90)
    # Perform sliding window average pooling to smooth the heatmap
    # NB: the window takes into account only the rows
    vndvi_rows_clipped_pooled = sliding_window_avg_pooling(
        np.where(mask_rows, vndvi_clipped, 0)[..., None],
        window = int(window_size / 4),
        granularity = granularity,
        alpha = mask_rows[..., None],
        min_nonblank_pixels = 0.0,
        verbose=True,
    )
    # Same, but for interrows
    vndvi_interrows_clipped_pooled = sliding_window_avg_pooling(
        np.where(mask_interrows, vndvi_clipped, 0)[..., None],
        window = int(window_size / 4),
        granularity = granularity,
        alpha = mask_interrows[..., None],
        min_nonblank_pixels = 0.0,
        verbose=True,
    )
    # Apply dilation to the rows heatmap and mask, only when it is requested
    if dilate_rows:
        dil_factor = int(window_size / 60)
        rows_alpha_mask = grey_dilation(mask_rows, size=(dil_factor, dil_factor))[..., None]
        rows_heatmap = grey_dilation(vndvi_rows_clipped_pooled, size=(dil_factor, dil_factor, 1))
    else:
        rows_alpha_mask = mask_rows[..., None]
        rows_heatmap = vndvi_rows_clipped_pooled
    # For visualization purposes, normalize with vndvi_perc10 and
    # vndvi_perc90 (because we want vndvi_perc10 to be the first color of
    # the colormap and vndvi_perc90 to be the last).
    # When the two percentiles coincide the range is 0: divide by 1 instead,
    # so that the heatmaps do not fill up with NaN/inf.
    vndvi_range = vndvi_perc90 - vndvi_perc10
    if vndvi_range == 0:
        vndvi_range = 1.0
    vndvi_rows_normalized = (rows_heatmap - vndvi_perc10) / vndvi_range
    vndvi_interrows_clipped_pooled_normalized = (vndvi_interrows_clipped_pooled - vndvi_perc10) / vndvi_range
    # for visualization
    vndvi_rows_img = alpha_composite(
        raster,
        vndvi_rows_normalized,
        opacity = 1.0,
        colormap = 'RdYlGn',
        alpha_image = np.zeros_like(raster[:,:,[0]]),
        alpha_mask = rows_alpha_mask,
    ) # HW4 RGBA
    vndvi_interrows_img = alpha_composite(
        raster,
        vndvi_interrows_clipped_pooled_normalized,
        opacity = 1.0,
        colormap = 'RdYlGn',
        alpha_image = np.zeros_like(raster[:,:,[0]]),
        alpha_mask = mask_interrows[...,None],
    ) # HW4 RGBA
    # NOTE: a colorbar for these images would span [vndvi_perc10, vndvi_perc90]
    # with the 'RdYlGn' colormap; figure creation is left to the caller.
    return vndvi_rows_img, vndvi_interrows_img
def compute_vdi(
    raster: np.ndarray,
    mask: np.ndarray,
    window_size=360,
    granularity=40,
):
    """Compute the VDI heatmap of a vineyard and overlay it on the image.

    The VDI is obtained by sliding-window average pooling of the rows mask,
    weighted by the valid (rows or interrows) pixels and normalized to [0, 1].

    Args:
        raster (np.ndarray): image laid out as CHW.
        mask (np.ndarray): mask with values 0=interrows, 255=rows, 1=nodata.
        window_size (int): side of the pooling window.
        granularity (int): step of the pooling window.

    Returns:
        The image produced by `alpha_composite` with the 'jet_r' colormap.
    """
    # CHW -> HWC
    raster_hwc = raster.transpose(1, 2, 0)

    # Mask has values: 0=interrows, 255=rows, 1=nodata
    rows = (mask == 255)
    interrows = (mask == 0)
    valid = rows | interrows

    # compute vdi; the extrema before normalization are returned as well
    # (they are what a colorbar for this image would span) but are not used here
    pooled = sliding_window_avg_pooling(
        rows[..., None],
        window=window_size,
        granularity=granularity,
        alpha=valid[..., None],
        min_nonblank_pixels=0.9,
        normalize=True,
        return_min_max=True,
        verbose=True,
    )
    vdi, _vdi_min, _vdi_max = pooled

    # for visualization
    return alpha_composite(
        raster_hwc,
        vdi,
        opacity=1,
        colormap='jet_r',
        alpha_image=valid[..., None],
        alpha_mask=valid[..., None],
    )
def compute_mask(
    raster: np.ndarray,
    model: torch.nn.Module,
    patch_size=512,
    stride=256,
    scaling_factor=None,
    rotate=False,
    batch_size=16
):
    """Segment a raster and turn the confidence scores into a uint8 mask.

    Args:
        raster (np.ndarray): image as (bands, rows, cols), with 3 (RGB) or
            4 (RGBA) bands.
        model (torch.nn.Module): segmentation model passed to `segment`.
        patch_size, stride, rotate, batch_size: forwarded to `segment`.
        scaling_factor (float, optional): resize factor forwarded to
            `segment`; a missing (or zero) value means 1.

    Returns:
        np.ndarray: uint8 mask with 255 where the score is above 0.5, 0 where
            it is not, and 1 where `segment` reported nodata (score -1).
    """
    assert isinstance(raster, np.ndarray), f'Input raster must be a numpy array. Got {type(raster)}'
    assert len(raster.shape) == 3, f'Input raster must have 3 dimensions (bands, rows, cols). Got shape {raster.shape}'
    assert raster.shape[0] in [3,4], f'Input raster must have 3 bands (RGB) or 4 bands (RGBA). Got {raster.shape[0]} bands'
    assert isinstance(model, torch.nn.Module), 'Model must be a torch.nn.Module'

    # run on GPU whenever one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Growseg works best on orthoimages with gsd in [1, 1.7] cm/px. You may want to
    # specify a scaling factor different from 1 if your image has a different gsd,
    # e.g. gsd / 0.015 (the gsd is not inferred here: it is only known for GeoTIFFs).
    scaling_factor = scaling_factor or 1
    logger.info(f'Applying scaling factor: {scaling_factor:.2f}')

    # segment
    logger.info('Segmenting image...')
    segment_kwargs = dict(
        patch_size=patch_size,
        stride=stride,
        scaling_factor=scaling_factor,
        rotate=rotate,
        device=device,
        batch_size=batch_size,
        verbose=True,
    )
    score_map = segment(raster, model, **segment_kwargs)  # HxW scores, -1 on nodata pixels

    nodata = (score_map == -1)
    # apply threshold on confidence scores and convert to uint8
    mask = np.where(score_map > 0.5, 255, 0).astype(np.uint8)
    # set nodata pixels to 1
    mask[nodata] = 1
    return mask