import numpy as np
import torch
import rasterio
import cv2
from transformers import SegformerForSemanticSegmentation
from tqdm import tqdm
from PIL import Image
from scipy.ndimage import grey_dilation
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from .viz_utils import alpha_composite
def read_raster(path, order='CHW'):
    """Read every band of a raster file into a numpy array.

    Args:
        path: path of the raster to open with ``rasterio.open``.
        order (str): axis order of the returned array: 'CHW' (bands first, the
            layout returned by rasterio) or 'HWC' (bands last).

    Returns:
        np.ndarray: the raster bands, laid out according to ``order``.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    with rasterio.open(path) as src:
        # read() without a band index returns all bands as a (bands, rows, cols) array
        img = src.read()
    if order == 'HWC':
        img = np.moveaxis(img, 0, -1)
    return img
def write_raster(path, img, profile, order='CHW'):
    """Write a numpy array to a raster file.

    Args:
        path: destination path.
        img (np.ndarray): 3D array of bands laid out according to ``order``.
        profile (dict): keyword arguments forwarded to ``rasterio.open`` in
            write mode; it should describe ``img`` (band count, size, dtype).
        order (str): 'CHW' (bands first) or 'HWC' (bands last).
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    if order == 'HWC':
        # rasterio writes bands-first arrays
        img = np.moveaxis(img, -1, 0)
    with rasterio.open(path, 'w', **profile) as dst:
        dst.write(img)
def resize(img, shape=None, scaling_factor=1., order='CHW'):
    """Resize an image with bilinear interpolation.

    Either give a target ``shape`` (H, W), or a ``scaling_factor`` applied to
    both spatial axes — not both.

    Args:
        img (np.ndarray): 3D image laid out according to ``order``.
        shape (tuple, optional): target (H, W). Takes precedence over scaling.
        scaling_factor (float): factor applied to H and W when ``shape`` is None.
        order (str): 'CHW' or 'HWC'.

    Returns:
        np.ndarray: the resized image in the same ``order`` as the input. A
        single-channel result keeps its explicit channel axis.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    assert shape is None or scaling_factor == 1., "Got both shape and scaling_factor. Please provide only one of them"
    if order == 'CHW':
        img = np.moveaxis(img, 0, -1)  # CHW -> HWC, the layout cv2 works with
    # moveaxis returns a strided view: hand cv2 a contiguous buffer
    img = np.ascontiguousarray(img)
    if shape is not None:
        # cv2 expects dsize as integer (width, height)
        dsize = (int(shape[1]), int(shape[0]))
        img = cv2.resize(img, dsize, interpolation=cv2.INTER_LINEAR)
    else:
        img = cv2.resize(img, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_LINEAR)
    # NB: cv2.resize returns a HW image if the input image is HW1: restore the C dimension
    if img.ndim == 2:
        img = img[..., None]
    if order == 'CHW':
        img = np.moveaxis(img, -1, 0)  # HWC -> CHW
    return img
def minimum_needed_padding(img_size, patch_size: int, stride: int):
    """Compute the minimum padding needed to tile an image with strided patches.

    After padding, every spatial side can be covered exactly by windows of side
    ``patch_size`` placed every ``stride`` pixels. Each padded side is:
    - ``patch_size`` if the image is not larger than a patch (0 for an empty axis);
    - otherwise ``patch_size + k * stride`` for some integer k.

    The padding is split as evenly as possible between the two sides. Any extra
    pixel goes to the bottom/right.

    Args:
        img_size (tuple): the (H, W) size of the image.
        patch_size (int): the size of the patches to extract.
        stride (int): the stride to use when extracting patches.

    Returns:
        tuple: (pad_t, pad_b, pad_l, pad_r) as Python ints.

    Raises:
        ValueError: if ``patch_size`` or ``stride`` is not positive.
    """
    if patch_size <= 0 or stride <= 0:
        raise ValueError(f"patch_size and stride must be positive. Got patch_size={patch_size}, stride={stride}")
    img_size = np.asarray(img_size)
    pad = np.where(
        img_size <= patch_size,
        (patch_size - img_size) % patch_size,  # the % patch_size handles img_size == 0 (no padding)
        (stride - (img_size - patch_size)) % stride,  # smallest pad making (size - patch_size) a multiple of stride
    )
    pad_t, pad_l = (int(v) for v in pad // 2)
    pad_b, pad_r = int(pad[0]) - pad_t, int(pad[1]) - pad_l
    return pad_t, pad_b, pad_l, pad_r
def pad(img, pad, order='CHW'):
    """Zero-pad the spatial axes of an image.

    Args:
        img (np.ndarray | torch.Tensor): 3D image laid out according to ``order``.
        pad (tuple): (pad_t, pad_b, pad_l, pad_r), the number of zero rows /
            columns added at the top, bottom, left and right.
        order (str): 'CHW' or 'HWC'.

    Returns:
        The padded image, with the same type as ``img``. Torch tensors are padded
        with ``torch.nn.functional.pad``, so they keep their dtype and device.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    pad_t, pad_b, pad_l, pad_r = (int(v) for v in pad)
    if isinstance(img, torch.Tensor):
        # F.pad takes (before, after) pairs starting from the LAST axis
        if order == 'HWC':
            torch_pad = (0, 0, pad_l, pad_r, pad_t, pad_b)
        else:
            torch_pad = (pad_l, pad_r, pad_t, pad_b)
        return torch.nn.functional.pad(img, torch_pad, mode='constant', value=0)
    if order == 'HWC':
        pad_width = ((pad_t, pad_b), (pad_l, pad_r), (0, 0))
    else:
        pad_width = ((0, 0), (pad_t, pad_b), (pad_l, pad_r))
    return np.pad(img, pad_width, mode='constant', constant_values=0)  # can also try mode='reflect'
def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True):
    """Slide a square window over an image and collect the boxes it covers.

    Windows of side ``patch_size`` are placed every ``stride`` pixels, starting
    from the top-left corner, in row-major order. Only windows that fit entirely
    inside the image are kept. As a result, the last rows/columns are not covered
    when ``H - patch_size`` (or ``W - patch_size``) is not a multiple of
    ``stride``.

    Args:
        img: 3D image (numpy array or tensor) laid out according to ``order``.
        patch_size (int): side of the square window.
        stride (int): step between consecutive windows.
        order (str): 'CHW' or 'HWC'.
        only_return_idx (bool): if True, return only the boxes.

    Returns:
        The list of (h_start, h_end, w_start, w_end) boxes if ``only_return_idx``.
        Otherwise the tuple ``(patches, boxes)``, where ``patches`` are slices of
        ``img`` in the same order as the boxes.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    if order == 'HWC':
        H, W = img.shape[:2]
    else:
        H, W = img.shape[1:]
    patches = []
    patches_idx = []
    for i in range(0, H - patch_size + 1, stride):
        for j in range(0, W - patch_size + 1, stride):
            patches_idx.append((i, i + patch_size, j, j + patch_size))
            if only_return_idx:
                continue
            if order == 'HWC':
                patches.append(img[i:i + patch_size, j:j + patch_size, :])
            else:
                patches.append(img[:, i:i + patch_size, j:j + patch_size])
    if only_return_idx:
        return patches_idx
    return patches, patches_idx
def segment_batch(batch, model):
    """Predict per-pixel confidence scores for a batch of patches.

    Args:
        batch (torch.Tensor): stack of CHW patches, shaped (N, C, H, W), expected
            to be on the same device as ``model``.
        model: either a ``SegformerForSemanticSegmentation``, whose output
            ``.logits`` are resized back to the batch's (H, W), or a module whose
            call returns the logits directly.

    Returns:
        torch.Tensor: ``torch.sigmoid`` of the logits, computed under
        ``torch.no_grad()``.
    """
    with torch.no_grad():
        logits = model(batch)
        if isinstance(model, SegformerForSemanticSegmentation):
            # SegFormer returns an output object; its logits may not match the input resolution
            logits = upsample(logits.logits, size=batch.shape[-2:])
        return torch.sigmoid(logits)  # logits -> confidence scores
def upsample(x, size, mode=None):
    """Resize a 3D/4D/5D tensor to ``size`` by interpolation.

    Args:
        x (torch.Tensor): (N, C, L), (N, C, H, W) or (N, C, D, H, W) tensor.
        size: target spatial size (int or tuple matching the spatial rank).
        mode (str, optional): interpolation mode. If None (default), the linear
            mode matching the rank is used: 'linear' for 3D, 'bilinear' for 4D,
            'trilinear' for 5D.

    Returns:
        torch.Tensor: the resized tensor.
    """
    if mode is None:
        # any other rank falls back to 'bilinear', so torch raises its usual error
        mode = {3: 'linear', 4: 'bilinear', 5: 'trilinear'}.get(x.dim(), 'bilinear')
    # align_corners may only be set for the interpolating modes
    align_corners = False if mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
    return torch.nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
def merge_patches(patches, patches_idx, rotate=False, canvas_shape=None, order='CHW'):
    """Merge (possibly overlapping) patches into a single image by averaging.

    Args:
        patches: sequence of 3D patches laid out according to ``order``, all
            with the same number of channels.
        patches_idx: sequence of (t, b, l, r) boxes, one per patch. Each patch
            covers canvas rows [t, b) and columns [l, r).
        rotate (bool): if True, patch i is first rotated by -i quarter turns
            (``np.rot90`` direction). This undoes an augmentation where the
            patches come in consecutive groups of 4, rotated by 0, 1, 2 and 3
            quarter turns.
        canvas_shape (tuple, optional): (H, W) of the output. If None, the
            smallest canvas containing every box is used.
        order (str): 'CHW' or 'HWC'.

    Returns:
        np.ndarray: floating-point canvas in which each pixel is the mean of the
        patches covering it, and 0 where no patch covers it.

    Raises:
        ValueError: if ``patches`` is empty.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    if len(patches) == 0:
        raise ValueError("Cannot merge an empty sequence of patches")
    if rotate:
        axes_to_rotate = (0, 1) if order == 'HWC' else (1, 2)
        # np.rot90 works modulo 4, so -i undoes a rotation of (i % 4) quarter turns
        patches = [np.rot90(p, -i, axes=axes_to_rotate) for i, p in enumerate(patches)]
    assert len(patches) == len(patches_idx), f"Got {len(patches)} patches and {len(patches_idx)} indexes"
    # if canvas_shape is None, infer it from the bottom/right edges of the boxes
    if canvas_shape is None:
        canvas_H = max(b for _, b, _, _ in patches_idx)
        canvas_W = max(r for _, _, _, r in patches_idx)
    else:
        canvas_H, canvas_W = canvas_shape
    # accumulate in floating point: summing overlapping integer patches in their own dtype would overflow
    acc_dtype = np.result_type(patches[0].dtype, np.float64)
    if order == 'HWC':
        canvas_C = patches[0].shape[-1]
        canvas = np.zeros((canvas_H, canvas_W, canvas_C), dtype=acc_dtype)
        n_overlapping_patches = np.zeros((canvas_H, canvas_W, 1))
    else:
        canvas_C = patches[0].shape[0]
        canvas = np.zeros((canvas_C, canvas_H, canvas_W), dtype=acc_dtype)
        n_overlapping_patches = np.zeros((1, canvas_H, canvas_W))
    for p, (t, b, l, r) in zip(patches, patches_idx):
        if order == 'HWC':
            canvas[t:b, l:r, :] += p
            n_overlapping_patches[t:b, l:r, 0] += 1
        else:
            canvas[:, t:b, l:r] += p
            n_overlapping_patches[0, t:b, l:r] += 1
    # `out` must be pre-filled: np.divide leaves entries where `where` is False uninitialised
    return np.divide(
        canvas,
        n_overlapping_patches,
        out=np.zeros_like(canvas),
        where=(n_overlapping_patches != 0),
    )
def segment(img, model, patch_size=512, stride=256, scaling_factor=1., rotate=False, device=None, batch_size=16, verbose=False):
    """Segment an RGB(A) image with a segmentation model using overlapping patches.

    Pipeline:
    1. Rescale the image by ``scaling_factor`` and zero-pad it so that square
       patches of side ``patch_size``, taken every ``stride`` pixels, tile it
       exactly.
    2. Normalise every non-blank patch with the ImageNet mean/std and feed it
       to the model in batches of ``batch_size`` (see ``segment_batch``).
    3. Average the predictions where patches overlap, remove the padding and
       resize the result back to the input resolution.

    Args:
        img (np.ndarray): uint8 image formatted as CHW with C = 3 (RGB) or
            C = 4 (RGBA, the 4th channel being used as alpha).
        model: segmentation model (a torch module) accepted by ``segment_batch``.
        patch_size (int): side of the square patches fed to the model.
        stride (int): step between consecutive patches.
        scaling_factor (float): factor applied to the image before patching.
        rotate (bool): if True, each patch is also predicted rotated by 90, 180
            and 270 degrees, and the un-rotated predictions are averaged.
        device: torch device used for inference. If None, the device of the
            model's parameters is used (CPU if the model has none).
        batch_size (int): number of patches that triggers a forward pass.
        verbose (bool): show a progress bar.

    Returns:
        np.ndarray: squeezed map of averaged sigmoid confidences, set to -1
        where the alpha channel is 0.
    """
    # some checks
    assert isinstance(img, np.ndarray), f"Input must be a numpy array. Got {type(img)}"
    assert img.shape[0] in [3,4], f"Input image must be formatted as CHW, with C = 3,4. Got a shape of {img.shape}"
    assert img.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {img.dtype}"

    # prepare model for evaluation; without an explicit device, run where the model already lives
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model = model.to(device).eval()

    # split off the alpha channel, or create a fully opaque one
    original_shape = img.shape
    if img.shape[0] == 3:
        alpha = np.full(original_shape[1:], 255, dtype=np.uint8)
    else:
        img, alpha = img[:3], img[3]

    # resize, then zero-pad so that the patches tile the image exactly
    img = resize(img, scaling_factor=scaling_factor)
    pad_t, pad_b, pad_l, pad_r = minimum_needed_padding(img.shape[1:], patch_size, stride)
    padded_img = pad(img, pad=(pad_t, pad_b, pad_l, pad_r))
    padded_shape = padded_img.shape
    patches_idx = extract_patches(padded_img, patch_size=patch_size, stride=stride)

    # ImageNet normalization constants, built once instead of once per patch
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    masks = []      # numpy predictions, one array per forward pass
    masks_idx = []  # box of every predicted patch, in prediction order
    batch = []

    def _predict(patch_batch):
        """Run the model on a list of CHW tensors and store the numpy predictions."""
        out = segment_batch(torch.stack(patch_batch).to(device), model)
        masks.append(out.cpu().numpy())

    for p_idx in tqdm(patches_idx, disable=not verbose, desc="Predicting...", total=len(patches_idx)):
        t, b, l, r = p_idx
        patch = padded_img[:, t:b, l:r]
        # consider patch only if it is valid (i.e. not all black or all white)
        if not (np.any(patch != 0) and np.any(patch != 255)):
            continue
        # float32 in [0,1], then ImageNet normalization
        patch = (torch.tensor(patch).float() / 255. - mean) / std
        batch.append(patch)
        masks_idx.append(p_idx)
        if rotate:
            # rotate the ORIGINAL patch (not cumulatively): entry k of each group of 4
            # is rotated by k quarter turns, which merge_patches undoes with -i
            for rot in range(1, 4):
                batch.append(torch.rot90(patch, rot, dims=[1, 2]))
                masks_idx.append(p_idx)
        if len(batch) >= batch_size:
            _predict(batch)
            batch = []
    # predict the last, partially filled batch (if any)
    if batch:
        _predict(batch)

    if masks:
        masks = np.concatenate(masks)
        mask = merge_patches(masks, masks_idx, rotate=rotate, canvas_shape=padded_shape[1:])
    else:
        # every patch was blank: there is no prediction to merge
        mask = np.zeros((1, *padded_shape[1:]))
    # undo padding
    mask = mask[:, pad_t:padded_shape[1]-pad_b, pad_l:padded_shape[2]-pad_r]
    # resize mask to original shape
    mask = resize(mask, shape=original_shape[1:])
    # apply alpha channel, i.e. set to -1 the pixels where alpha is 0
    mask = np.where(alpha == 0, -1, mask)
    return mask.squeeze()
def sliding_window_avg_pooling(img, window, granularity, alpha=None, min_nonblank_pixels=0., normalize=False, return_min_max=False, verbose=False):
    """Smooth a single-channel image with an alpha-weighted sliding-window mean.

    Square windows of side ``window`` are taken every ``granularity`` pixels.
    A window is used only if its alpha sum exceeds
    ``min_nonblank_pixels * window**2``. Each used window is replaced by its
    alpha-weighted mean, the windows are averaged where they overlap (via
    ``merge_patches``), and the result is multiplied by alpha.

    Args:
        img (np.ndarray): HWC image with C = 1.
        window (int): side of the square window.
        granularity (int): stride between consecutive windows.
        alpha (np.ndarray, optional): HWC weights with the same shape as ``img``.
            Conversion depends on dtype:
            - uint8 values are rescaled from [0, 255] to [0, 1];
            - booleans are cast to 0/1;
            - any other dtype is used as-is.
            Defaults to all ones.
        min_nonblank_pixels (float): fraction of the window area that the alpha
            sum must exceed for the window to be used.
        normalize (bool): if True, min-max normalize the result to [0, 1]. A
            constant result is mapped to 0.
        return_min_max (bool): if True, also return the min and max of the
            pooled image, taken before normalization.
        verbose (bool): show a progress bar while selecting windows.

    Returns:
        The pooled HWC image, or ``(pooled, min, max)`` if ``return_min_max``.
    """
    assert isinstance(img, np.ndarray), f'Input image must be a numpy array. Got {type(img)}'
    assert img.shape[2] == 1, f'Input image must be formatted as HWC, with C = 1. Got a shape of {img.shape}'
    # check if alpha channel was given, and cast it to np.float32 with values in [0,1]
    if alpha is not None:
        assert isinstance(alpha, np.ndarray), f'Alpha channel must be a numpy array. Got {type(alpha)}'
        assert alpha.shape[2] == 1, f'Alpha channel must be formatted as HWC, with C = 1. Got a shape of {alpha.shape}'
        assert img.shape == alpha.shape, f'The shape of input image {img.shape} and alpha channel {alpha.shape} do not match'
        if alpha.dtype == np.uint8:
            alpha = (alpha / 255).astype(np.float32)
        elif alpha.dtype == bool:
            alpha = alpha.astype(np.float32)
    else:
        # float weights: np.ones_like(img) would inherit a bool/int dtype and truncate the densities
        alpha = np.ones(img.shape, dtype=np.float32)

    patches, patches_idx = extract_patches(img, patch_size=window, stride=granularity, order='HWC', only_return_idx=False)
    patches_alpha, _ = extract_patches(alpha, patch_size=window, stride=granularity, order='HWC', only_return_idx=False)

    # keep only windows with enough non-blank pixels
    min_alpha_sum = min_nonblank_pixels * window**2
    kept_patches = []
    for i, p_a in tqdm(enumerate(patches_alpha), total=len(patches_alpha), disable=not verbose):
        alpha_sum = p_a.sum()
        # alpha_sum > 0 also guarantees the weighted mean below is defined
        if alpha_sum > min_alpha_sum and alpha_sum > 0:
            kept_patches.append(i)

    if kept_patches:
        # alpha-weighted mean of each kept window (i.e. density inside the window)
        patches_density = [
            np.full_like(patches_alpha[i], (patches[i] * patches_alpha[i]).sum() / patches_alpha[i].sum())
            for i in kept_patches
        ]
        pooled_img = merge_patches(
            patches_density,
            [patches_idx[i] for i in kept_patches],
            canvas_shape=img.shape[:2],
            order='HWC',
        )
    else:
        # no window has enough non-blank pixels: nothing to average
        pooled_img = np.zeros(img.shape, dtype=np.float64)
    pooled_img = pooled_img * alpha

    pooled_img_min = pooled_img.min()
    pooled_img_max = pooled_img.max()
    if normalize:
        value_range = pooled_img_max - pooled_img_min
        if value_range > 0:
            pooled_img = (pooled_img - pooled_img_min) / value_range
        else:
            pooled_img = np.zeros_like(pooled_img)
    if return_min_max:
        return pooled_img, pooled_img_min, pooled_img_max
    return pooled_img
def compute_vndvi(image, mask, dilate_rows=True, window_size=360):
    """Build vNDVI heatmap figures for the vine rows and for the inter-rows.

    Steps:
    1. Compute vNDVI = 0.5268 * R^-0.1294 * G^0.3389 * B^-0.3118 per pixel on
       the RGB channels scaled to [0, 1].
    2. Clip the result to [0, 1], then to its 10th-90th percentile range over
       the valid pixels.
    3. Smooth rows and inter-rows separately with a sliding-window average of
       side ``window_size / 4``.

    Args:
        image (np.ndarray): uint8 CHW image; channels 0, 1, 2 are used as R, G, B.
        mask (np.ndarray): uint8 (H, W) mask with values:
            - 255 for rows;
            - 0 for inter-rows;
            - 1 for nodata.
        dilate_rows (bool): if True, dilate the rows heatmap and the rows mask
            with a square of side ``window_size / 60`` before plotting.
        window_size (int): reference window size in pixels.

    Returns:
        tuple: (fig_rows, fig_interrows) matplotlib figures, each with a vNDVI
        colorbar spanning the 10th-90th percentile range.
    """
    assert image.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {image.dtype}"
    assert mask.dtype == np.uint8, f"Input mask must be a numpy array with dtype np.uint8. Got {mask.dtype}"
    image = image.transpose(1, 2, 0)  # CHW -> HWC

    rgb = image.astype(np.float32) / 255  # float32 in [0,1]
    R, G, B = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    # R and B have negative exponents: replace 0 with 1 to avoid dividing by zero
    R = np.where(R == 0, 1, R)
    B = np.where(B == 0, 1, B)
    vndvi = 0.5268 * (R**(-0.1294) * G**(0.3389) * B**(-0.3118))
    vndvi = np.clip(vndvi, 0, 1)

    # 10th and 90th percentile over the valid pixels (mask is 1 for nodata, 0 or 255 for valid pixels)
    vndvi_perc10, vndvi_perc90 = np.percentile(vndvi[mask != 1], [10, 90])
    vndvi_clipped = np.clip(vndvi, vndvi_perc10, vndvi_perc90)

    # smooth the heatmap separately on rows and inter-rows, so the windows only average pixels of one class
    rows = mask == 255
    interrows = mask == 0
    pooling_window = int(window_size / 4)
    vndvi_rows_clipped_pooled = sliding_window_avg_pooling(
        np.where(rows, vndvi_clipped, 0)[..., None],
        window=pooling_window,
        granularity=10,
        alpha=rows[..., None],
        min_nonblank_pixels=0.0,
    )
    vndvi_interrows_clipped_pooled = sliding_window_avg_pooling(
        np.where(interrows, vndvi_clipped, 0)[..., None],
        window=pooling_window,
        granularity=10,
        alpha=interrows[..., None],
        min_nonblank_pixels=0.0,
    )

    if dilate_rows:
        dil_factor = max(1, int(window_size / 60))  # a size of 0 is invalid for grey_dilation
        rows_mask = grey_dilation(rows, size=(dil_factor, dil_factor))
        vndvi_rows = grey_dilation(vndvi_rows_clipped_pooled, size=(dil_factor, dil_factor, 1))
    else:
        rows_mask = rows
        vndvi_rows = vndvi_rows_clipped_pooled

    # map [perc10, perc90] to [0, 1] so they get the first/last colors of the colormap
    perc_range = vndvi_perc90 - vndvi_perc10
    if perc_range == 0:
        perc_range = 1.0  # constant heatmap: avoid dividing by zero
    vndvi_rows_normalized = (vndvi_rows - vndvi_perc10) / perc_range
    vndvi_interrows_normalized = (vndvi_interrows_clipped_pooled - vndvi_perc10) / perc_range

    # NOTE(review): these argument lists were reconstructed from a truncated source;
    # they assume alpha_composite(base_image, heatmap, ...) — confirm against viz_utils.alpha_composite.
    transparent_base = np.zeros_like(image[:, :, [0]])
    vndvi_rows_img = alpha_composite(
        image,
        vndvi_rows_normalized,
        opacity=1.0,
        colormap='RdYlGn',
        alpha_image=transparent_base,
        alpha_mask=rows_mask[..., None],
    )
    vndvi_interrows_img = alpha_composite(
        image,
        vndvi_interrows_normalized,
        opacity=1.0,
        colormap='RdYlGn',
        alpha_image=transparent_base,
        alpha_mask=interrows[..., None],
    )

    def _figure_with_colorbar(composite):
        """Show `composite` in a new figure with a vNDVI colorbar on the right."""
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.imshow(composite)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.15)
        fig.colorbar(
            mappable=mpl.cm.ScalarMappable(
                norm=mpl.colors.Normalize(vmin=vndvi_perc10, vmax=vndvi_perc90),
                cmap='RdYlGn'),
            cax=cax,
            orientation='vertical',
            label='vNDVI',
            shrink=1)
        return fig

    fig_rows = _figure_with_colorbar(vndvi_rows_img)
    fig_interrows = _figure_with_colorbar(vndvi_interrows_img)
    return fig_rows, fig_interrows
def compute_vdi(image, mask, window_size=360):
    """Build a VDI heatmap figure: the fraction of row pixels in a sliding window.

    Args:
        image (np.ndarray): uint8 CHW image.
        mask (np.ndarray): uint8 (H, W) mask with values:
            - 255 for rows;
            - 0 for inter-rows;
            - 1 for nodata.
        window_size (int): side of the square sliding window, in pixels.

    Returns:
        matplotlib figure with the heatmap and a VDI colorbar spanning the
        pooled min/max.
    """
    assert image.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {image.dtype}"
    assert mask.dtype == np.uint8, f"Input mask must be a numpy array with dtype np.uint8. Got {mask.dtype}"
    image = image.transpose(1, 2, 0)  # CHW -> HWC

    valid = (mask != 1)[..., None]  # mask is 1 for nodata, 0 or 255 for valid pixels
    # NOTE(review): the pooling call was truncated in the source. return_min_max=True is
    # required by the 3-tuple unpacking; normalize=True is inferred from the colorbar using
    # [vdi_min, vdi_max] while the heatmap is colormapped — confirm.
    vdi, vdi_min, vdi_max = sliding_window_avg_pooling(
        (mask == 255)[..., None],
        window=window_size,
        granularity=10,
        alpha=valid,
        min_nonblank_pixels=0.9,
        normalize=True,
        return_min_max=True,
    )

    # NOTE(review): reconstructed as alpha_composite(base_image, heatmap, ...) — confirm
    # against viz_utils.alpha_composite.
    vdi_img = alpha_composite(
        image,
        vdi,
        opacity=0.5,
        colormap='jet_r',
        alpha_image=valid,
        alpha_mask=valid,
    )

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(vdi_img)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.15)
    fig.colorbar(
        mappable=mpl.cm.ScalarMappable(
            norm=mpl.colors.Normalize(vmin=vdi_min, vmax=vdi_max),
            cmap='jet_r'),
        cax=cax,
        orientation='vertical',
        label='VDI',
        shrink=1)
    return fig