import numpy as np
import torch
import rasterio
import cv2
from transformers import SegformerForSemanticSegmentation
from tqdm import tqdm
from PIL import Image
from scipy.ndimage import grey_dilation
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from .viz_utils import alpha_composite


def _check_order(order):
    """Assert that ``order`` is one of the supported axis layouts ('HWC' or 'CHW')."""
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"


def read_raster(path, order='CHW'):
    """Read a raster file and return a numpy array.

    Args:
        path: path of a raster file readable by rasterio.
        order: axis layout of the returned array, 'CHW' (band-first, as read
            by rasterio) or 'HWC'.

    Returns:
        np.ndarray containing all bands of the raster.
    """
    _check_order(order)
    with rasterio.open(path) as src:
        img = src.read()  # band-first layout
    if order == 'HWC':
        img = np.moveaxis(img, 0, -1)
    return img


def write_raster(path, img, profile, order='CHW'):
    """Write a numpy array to a raster file.

    Args:
        path: destination path.
        img: array to write, laid out as described by ``order``.
        profile: rasterio profile (driver, dtype, size, crs, ...) passed to ``rasterio.open``.
        order: axis layout of ``img``, 'CHW' or 'HWC'.
    """
    _check_order(order)
    if order == 'HWC':
        img = np.moveaxis(img, -1, 0)  # rasterio writes band-first arrays
    with rasterio.open(path, 'w', **profile) as dst:
        dst.write(img)


def resize(img, shape=None, scaling_factor=1., order='CHW'):
    """Resize an image (bilinear) to ``shape`` or by ``scaling_factor``.

    Args:
        img: image array laid out as described by ``order``.
        shape: target (H, W). Mutually exclusive with ``scaling_factor``.
        scaling_factor: isotropic scaling factor, used when ``shape`` is None.
        order: axis layout of ``img`` (and of the result), 'CHW' or 'HWC'.

    Returns:
        The resized image; the channel axis is preserved even when C == 1.
    """
    _check_order(order)
    assert shape is None or scaling_factor == 1., "Got both shape and scaling_factor. Please provide only one of them"
    if order == 'CHW':
        img = np.moveaxis(img, 0, -1)  # CHW -> HWC, the layout cv2 works with
    if shape is not None:
        # cv2 expects the destination size as (width, height)
        dsize = tuple(int(s) for s in shape[::-1])
        img = cv2.resize(img, dsize, interpolation=cv2.INTER_LINEAR)
    else:
        img = cv2.resize(img, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_LINEAR)
    # NB: cv2.resize returns a HW image if the input image is HW1: restore the C dimension
    if img.ndim == 2:
        img = img[..., None]
    if order == 'CHW':
        img = np.moveaxis(img, -1, 0)  # HWC -> CHW
    return img


def minimum_needed_padding(img_size, patch_size: int, stride: int):
    """Compute the minimum padding needed to tile an image exactly with patches.

    After padding, ``(size - patch_size)`` is a multiple of ``stride`` along each
    axis (or the size equals ``patch_size`` when the image is smaller than a patch),
    so that a sliding window of ``patch_size`` with step ``stride`` covers it exactly.

    Args:
        img_size (tuple): the shape (H, W) of the image
        patch_size (int): the size of the patches to extract
        stride (int): the stride to use when extracting patches

    Returns:
        tuple: (pad_t, pad_b, pad_l, pad_r); the padding is split as evenly as
        possible, with the extra pixel (if any) going to the bottom/right.
    """
    img_size = np.array(img_size)
    pad = np.where(
        img_size <= patch_size,
        (patch_size - img_size) % patch_size,  # the % patch_size is to handle the case img_size = (0,0)
        (stride - (img_size - patch_size)) % stride,
    )
    pad_t, pad_l = pad // 2
    pad_b, pad_r = pad[0] - pad_t, pad[1] - pad_l
    return pad_t, pad_b, pad_l, pad_r


def pad(img, pad, order='CHW'):
    """Zero-pad the spatial axes of an image.

    Args:
        img: numpy array (or CPU torch tensor) laid out as described by ``order``.
        pad: (pad_t, pad_b, pad_l, pad_r) padding amounts.
        order: axis layout, 'CHW' or 'HWC'. The channel axis is never padded.

    Returns:
        The padded image; a torch tensor if ``img`` was a torch tensor, else a numpy array.
    """
    _check_order(order)
    pad_t, pad_b, pad_l, pad_r = pad
    if order == 'HWC':
        pad_width = ((pad_t, pad_b), (pad_l, pad_r), (0, 0))
    else:
        pad_width = ((0, 0), (pad_t, pad_b), (pad_l, pad_r))
    padded_img = np.pad(img, pad_width, mode='constant', constant_values=0)  # can also try mode='reflect'
    if isinstance(img, torch.Tensor):
        padded_img = torch.tensor(padded_img)
    return padded_img


def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True):
    """Extract square patches with a sliding window (no padding is applied).

    Windows that would extend past the image border are not produced; pad the
    image beforehand (see ``minimum_needed_padding``) to cover it completely.

    Args:
        img: image array laid out as described by ``order``.
        patch_size: side of the square window.
        stride: step of the window along both axes.
        order: axis layout, 'CHW' or 'HWC'.
        only_return_idx: if True, return only the patch coordinates.

    Returns:
        ``patches_idx`` (list of (h_start, h_end, w_start, w_end)) if
        ``only_return_idx``, otherwise ``(patches, patches_idx)`` where patches
        are views into ``img``.
    """
    _check_order(order)
    if order == 'HWC':
        H, W = img.shape[:2]
    else:
        H, W = img.shape[1:]

    patches = []
    patches_idx = []
    for i in range(0, H - patch_size + 1, stride):
        for j in range(0, W - patch_size + 1, stride):
            patches_idx.append((i, i + patch_size, j, j + patch_size))
            if not only_return_idx:
                if order == 'HWC':
                    patch = img[i:i + patch_size, j:j + patch_size, :]
                else:
                    patch = img[:, i:i + patch_size, j:j + patch_size]
                patches.append(patch)
    if only_return_idx:
        return patches_idx
    return patches, patches_idx


def segment_batch(batch, model):
    """Run ``model`` on a batch of normalized patches and return sigmoid confidences.

    Args:
        batch: float tensor of patches, (N, 3, H, W).
        model: torch segmentation model. For ``SegformerForSemanticSegmentation``
            the ``.logits`` output is upsampled back to the batch's spatial size.
            Other models are assumed to return a logits tensor directly — TODO confirm.

    Returns:
        Tensor of per-pixel confidence scores in [0, 1].
    """
    with torch.no_grad():
        out = model(batch)  # presumably (N, 1, H, W) logits
        if isinstance(model, SegformerForSemanticSegmentation):
            # Segformer predicts at reduced resolution: bring logits back to patch size
            out = upsample(out.logits, size=batch.shape[-2:])
        out = torch.sigmoid(out)  # logits -> confidence scores
    return out


def upsample(x, size):
    """Bilinearly resize a 4D (N, C, H, W) tensor to spatial ``size`` (H, W)."""
    return torch.nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=False)


def merge_patches(patches, patches_idx, rotate=False, canvas_shape=None, order='CHW'):
    """Merge (possibly overlapping) patches into a single image by averaging.

    Args:
        patches: sequence of patch arrays laid out as described by ``order``.
        patches_idx: one (t, b, l, r) coordinate tuple per patch.
        rotate: if True, patches are expected in consecutive groups of 4 where the
            k-th element of each group was computed on a copy rotated by k quarter
            turns (as produced by ``segment``); each is rotated back before merging.
        canvas_shape: (H, W) of the output; inferred from ``patches_idx`` if None.
        order: axis layout, 'CHW' or 'HWC'.

    Returns:
        Float array of the canvas; each pixel is the mean of the patches covering
        it, and 0 where no patch covers it.

    Raises:
        ValueError: if ``patches`` is empty.
    """
    _check_order(order)
    assert len(patches) == len(patches_idx), f"Got {len(patches)} patches and {len(patches_idx)} indexes"
    if len(patches) == 0:
        raise ValueError("Cannot merge an empty list of patches")

    if rotate:
        # element i of the list was rotated by (i % 4) quarter turns: undo it
        axes_to_rotate = (0, 1) if order == 'HWC' else (1, 2)
        patches = [np.rot90(p, -(i % 4), axes=axes_to_rotate) for i, p in enumerate(patches)]

    # if canvas_shape is None, infer it from the patch end coordinates
    if canvas_shape is None:
        canvas_H = max(idx[1] for idx in patches_idx)
        canvas_W = max(idx[3] for idx in patches_idx)
    else:
        canvas_H, canvas_W = canvas_shape

    # accumulate in (at least) float64 so that summing integer patches cannot overflow
    acc_dtype = np.result_type(patches[0].dtype, np.float64)
    if order == 'HWC':
        canvas = np.zeros((canvas_H, canvas_W, patches[0].shape[-1]), dtype=acc_dtype)
        n_overlapping_patches = np.zeros((canvas_H, canvas_W, 1))
    else:
        canvas = np.zeros((patches[0].shape[0], canvas_H, canvas_W), dtype=acc_dtype)
        n_overlapping_patches = np.zeros((1, canvas_H, canvas_W))

    for p, (t, b, l, r) in zip(patches, patches_idx):
        if order == 'HWC':
            canvas[t:b, l:r, :] += p
            n_overlapping_patches[t:b, l:r, 0] += 1
        else:
            canvas[:, t:b, l:r] += p
            n_overlapping_patches[0, t:b, l:r] += 1

    # Average. `out` must be pre-filled: with `where=` NumPy leaves the positions
    # where the condition is False uninitialized, so uncovered pixels would
    # otherwise contain garbage instead of 0.
    out = np.zeros(canvas.shape, dtype=np.result_type(canvas, n_overlapping_patches))
    return np.divide(canvas, n_overlapping_patches, out=out, where=(n_overlapping_patches != 0))


def segment(img, model, patch_size=512, stride=256, scaling_factor=1., rotate=False, device=None, batch_size=16, verbose=False):
    """Segment an RGB(A) image with a segmentation model, patch by patch.

    The image is resized by ``scaling_factor``, zero-padded so that a sliding
    window of ``patch_size``/``stride`` tiles it exactly, and each patch that is
    not entirely 0 or entirely 255 is normalized with ImageNet statistics and fed
    to the model. Overlapping predictions are averaged, and the result is
    un-padded and resized back to the input resolution.

    Args:
        img: np.uint8 array of shape (C, H, W), C = 3 (RGB) or 4 (RGBA).
        model: torch segmentation model (see ``segment_batch``).
        patch_size: side of the square patches (in pixels of the resized image).
        stride: step of the sliding window.
        scaling_factor: resize factor applied before patch extraction.
        rotate: test-time augmentation; also predict on copies rotated by 90,
            180 and 270 degrees and average all predictions.
        device: device the model and the batches are moved to.
        batch_size: number of patches per forward pass (with ``rotate`` a batch
            may exceed it by up to 3).
        verbose: show a progress bar.

    Returns:
        Squeezed (H, W) float probability map with values in [0, 1]; pixels whose
        alpha is 0 are set to -1, and pixels covered only by skipped (blank)
        patches are 0.
    """
    assert isinstance(img, np.ndarray), f"Input must be a numpy array. Got {type(img)}"
    assert img.shape[0] in [3, 4], f"Input image must be formatted as CHW, with C = 3,4. Got a shape of {img.shape}"
    assert img.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {img.dtype}"

    model = model.to(device)
    model.eval()

    # prepare alpha channel
    original_shape = img.shape
    if img.shape[0] == 3:
        alpha = np.full(original_shape[1:], 255, dtype=np.uint8)  # dummy, fully opaque
    else:
        img, alpha = img[:3], img[3]

    img = resize(img, scaling_factor=scaling_factor)

    pad_t, pad_b, pad_l, pad_r = minimum_needed_padding(img.shape[1:], patch_size, stride)
    padded_img = pad(img, pad=(pad_t, pad_b, pad_l, pad_r))
    padded_shape = padded_img.shape

    patches_idx = extract_patches(padded_img, patch_size=patch_size, stride=stride)

    # ImageNet normalization statistics (hoisted out of the loop)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    masks = []
    masks_idx = []
    batch = []
    n_patches = len(patches_idx)
    for i, p_idx in enumerate(tqdm(patches_idx, disable=not verbose, desc="Predicting...", total=n_patches)):
        t, b, l, r = p_idx
        patch = padded_img[:, t:b, l:r]
        # consider patch only if it is valid (i.e. not all black or all white)
        if np.any(patch != 0) and np.any(patch != 255):
            patch = torch.tensor(patch).float() / 255.  # uint8 -> float32 in [0,1]
            patch = (patch - mean) / std
            batch.append(patch)
            masks_idx.append(p_idx)
            if rotate:
                # rotate the *unrotated* patch by k quarter turns (k = 1, 2, 3), so
                # that the k-th copy of each group of 4 is rotated by exactly k:
                # merge_patches relies on this ordering to undo the rotations
                for rot in range(1, 4):
                    batch.append(torch.rot90(patch, rot, dims=[1, 2]))
                    masks_idx.append(p_idx)

        # flush when the batch is full, or at the end if anything is pending
        if batch and (len(batch) >= batch_size or i == n_patches - 1):
            out = segment_batch(torch.stack(batch).to(device), model)
            masks.append(out.cpu().numpy())
            batch = []

    if masks:
        masks = np.concatenate(masks)  # presumably (n_predictions, 1, patch_size, patch_size)
        mask = merge_patches(masks, masks_idx, rotate=rotate, canvas_shape=padded_shape[1:])  # (1, H, W)
    else:
        # every patch was blank: nothing was predicted
        mask = np.zeros((1, *padded_shape[1:]))

    # undo padding
    mask = mask[:, pad_t:padded_shape[1] - pad_b, pad_l:padded_shape[2] - pad_r]
    # resize mask to original shape
    mask = resize(mask, shape=original_shape[1:])
    # apply alpha channel, i.e. set to -1 the pixels where alpha is 0
    mask = np.where(alpha == 0, -1, mask)
    return mask.squeeze()


def sliding_window_avg_pooling(img, window, granularity, alpha=None, min_nonblank_pixels=0., normalize=False, return_min_max=False, verbose=False):
    """Alpha-weighted sliding-window average of a single-channel image.

    For every ``window`` x ``window`` patch (step ``granularity``) whose total
    alpha exceeds ``min_nonblank_pixels * window**2``, the alpha-weighted mean of
    the patch is computed; overlapping windows are averaged per pixel and the
    result is multiplied by ``alpha``. Pixels covered by no kept window are 0.

    Args:
        img: (H, W, 1) numpy array.
        window: side of the square pooling window.
        granularity: step of the sliding window.
        alpha: optional (H, W, 1) weights; uint8 is scaled to [0, 1], bool is cast
            to float32, other dtypes are used as-is. Defaults to all ones.
        min_nonblank_pixels: minimum fraction of the window's area that the alpha
            sum must exceed for the window to be kept.
        normalize: if True, min-max normalize the result to [0, 1] (all zeros
            if the result is constant).
        return_min_max: if True, also return the min and max of the pooled image
            computed before normalization.
        verbose: show a progress bar.

    Returns:
        The pooled (H, W, 1) image, or ``(pooled, min, max)`` if ``return_min_max``.
    """
    assert isinstance(img, np.ndarray), f'Input image must be a numpy array. Got {type(img)}'
    assert img.shape[2] == 1, f'Input image must be formatted as HWC, with C = 1. Got a shape of {img.shape}'

    # check if alpha channel was given, and cast it to np.float32 with values in [0,1]
    if alpha is not None:
        assert isinstance(alpha, np.ndarray), f'Alpha channel must be a numpy array. Got {type(alpha)}'
        assert alpha.shape[2] == 1, f'Alpha channel must be formatted as HWC, with C = 1. Got a shape of {alpha.shape}'
        assert img.shape == alpha.shape, f'The shape of input image {img.shape} and alpha channel {alpha.shape} do not match'
        if alpha.dtype == np.uint8:
            alpha = (alpha / 255).astype(np.float32)
        elif alpha.dtype == bool:
            alpha = alpha.astype(np.float32)
    else:
        alpha = np.ones_like(img)

    patches, patches_idx = extract_patches(img, patch_size=window, stride=granularity, order='HWC', only_return_idx=False)
    patches_alpha, _ = extract_patches(alpha, patch_size=window, stride=granularity, order='HWC', only_return_idx=False)

    # keep only windows whose total alpha exceeds the threshold; the strict '>'
    # also guarantees a non-zero denominator for the weighted mean below
    min_weight = min_nonblank_pixels * window**2
    kept = [i for i, p_a in enumerate(tqdm(patches_alpha, disable=not verbose)) if p_a.sum() > min_weight]

    if kept:
        # alpha-weighted mean of each window, broadcast over the window's area
        patches_density = [
            np.full_like(patches_alpha[i], (patches[i] * patches_alpha[i]).sum() / patches_alpha[i].sum())
            for i in kept
        ]
        kept_idx = [patches_idx[i] for i in kept]
        pooled_img = merge_patches(patches_density, kept_idx, canvas_shape=img.shape[:2], order='HWC')
    else:
        # no window qualifies (or the image is smaller than the window)
        pooled_img = np.zeros(img.shape)

    pooled_img = pooled_img * alpha

    if normalize or return_min_max:
        pooled_img_min = pooled_img.min()
        pooled_img_max = pooled_img.max()
    if normalize:
        value_range = pooled_img_max - pooled_img_min
        if value_range > 0:
            pooled_img = (pooled_img - pooled_img_min) / value_range
        else:
            pooled_img = np.zeros_like(pooled_img)  # constant image: avoid 0/0
    if return_min_max:
        return pooled_img, pooled_img_min, pooled_img_max
    return pooled_img


def _figure_with_colorbar(img, vmin, vmax, cmap, label):
    """Show ``img`` in a new 10x10 figure with a vertical colorbar spanning [vmin, vmax]."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.15)
    ax.imshow(img)
    fig.colorbar(
        mappable=mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap),
        cax=cax,
        orientation='vertical',
        label=label,
        shrink=1,
    )
    return fig


def compute_vndvi(image, mask, dilate_rows=True, window_size=360):
    """Compute smoothed vNDVI heatmaps for vine rows and inter-rows.

    Args:
        image: np.uint8 RGB image of shape (3, H, W).
        mask: np.uint8 (H, W) mask where 255 marks rows, 0 inter-rows and 1 nodata.
        dilate_rows: if True, dilate the rows heatmap and mask for visualization.
        window_size: reference size in pixels; the pooling window is
            ``window_size / 4`` and the dilation size ``window_size / 60``.

    Returns:
        (fig_rows, fig_interrows): matplotlib figures of the two heatmaps overlaid
        on the image, each with a colorbar spanning the 10th-90th vNDVI percentiles.
    """
    assert image.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {image.dtype}"
    assert mask.dtype == np.uint8, f"Input mask must be a numpy array with dtype np.uint8. Got {mask.dtype}"

    image = image.transpose(1, 2, 0)  # CHW -> HWC

    rgb = image.astype(np.float32) / 255  # float32 in [0,1]
    R, G, B = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    # R and B are raised to negative powers: replace 0 with 1 to avoid division by 0
    R = np.where(R == 0, 1, R)
    B = np.where(B == 0, 1, B)

    vndvi = 0.5268 * (R**(-0.1294) * G**(0.3389) * B**(-0.3118))
    vndvi = np.clip(vndvi, 0, 1)

    # 10th/90th percentiles over all valid pixels (mask is 1 for nodata)
    vndvi_perc10, vndvi_perc90 = np.percentile(vndvi[mask != 1], [10, 90])
    vndvi_clipped = np.clip(vndvi, vndvi_perc10, vndvi_perc90)

    rows = mask == 255
    interrows = mask == 0
    pooling_window = int(window_size / 4)

    # smooth each heatmap with a sliding window that only averages pixels of its own class
    vndvi_rows_clipped_pooled = sliding_window_avg_pooling(
        np.where(rows, vndvi_clipped, 0)[..., None],
        window=pooling_window,
        granularity=10,
        alpha=rows[..., None],
        min_nonblank_pixels=0.0,
    )
    vndvi_interrows_clipped_pooled = sliding_window_avg_pooling(
        np.where(interrows, vndvi_clipped, 0)[..., None],
        window=pooling_window,
        granularity=10,
        alpha=interrows[..., None],
        min_nonblank_pixels=0.0,
    )

    if dilate_rows:
        dil_factor = max(1, int(window_size / 60))  # size 1 == no dilation
        rows_mask_vis = grey_dilation(rows, size=(dil_factor, dil_factor))
        rows_heatmap = grey_dilation(vndvi_rows_clipped_pooled, size=(dil_factor, dil_factor, 1))
    else:
        rows_mask_vis = rows
        rows_heatmap = vndvi_rows_clipped_pooled

    # normalize with the percentiles so that perc10 maps to the first colormap
    # color and perc90 to the last
    vndvi_range = vndvi_perc90 - vndvi_perc10
    rows_heatmap_normalized = (rows_heatmap - vndvi_perc10) / vndvi_range
    interrows_heatmap_normalized = (vndvi_interrows_clipped_pooled - vndvi_perc10) / vndvi_range

    vndvi_rows_img = alpha_composite(
        image,
        rows_heatmap_normalized,
        opacity=1.0,
        colormap='RdYlGn',
        alpha_image=np.zeros_like(image[:, :, [0]]),
        alpha_mask=rows_mask_vis[..., None],
    )
    vndvi_interrows_img = alpha_composite(
        image,
        interrows_heatmap_normalized,
        opacity=1.0,
        colormap='RdYlGn',
        alpha_image=np.zeros_like(image[:, :, [0]]),
        alpha_mask=interrows[..., None],
    )

    fig_rows = _figure_with_colorbar(vndvi_rows_img, vndvi_perc10, vndvi_perc90, 'RdYlGn', 'vNDVI')
    fig_interrows = _figure_with_colorbar(vndvi_interrows_img, vndvi_perc10, vndvi_perc90, 'RdYlGn', 'vNDVI')
    return fig_rows, fig_interrows


def compute_vdi(image, mask, window_size=360):
    """Compute the VDI heatmap: local density of vine-row pixels.

    Args:
        image: np.uint8 RGB image of shape (3, H, W).
        mask: np.uint8 (H, W) mask where 255 marks rows, 0 inter-rows and 1 nodata.
        window_size: side of the sliding pooling window, in pixels.

    Returns:
        matplotlib figure with the normalized density overlaid on the image and a
        colorbar spanning the un-normalized min/max density.
    """
    assert image.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {image.dtype}"
    assert mask.dtype == np.uint8, f"Input mask must be a numpy array with dtype np.uint8. Got {mask.dtype}"

    image = image.transpose(1, 2, 0)  # CHW -> HWC

    vdi, vdi_min, vdi_max = sliding_window_avg_pooling(
        (mask == 255)[..., None],
        window=window_size,
        granularity=10,
        alpha=(mask != 1)[..., None],  # mask is 1 for nodata, 0 or 255 for valid pixels
        min_nonblank_pixels=0.9,
        normalize=True,
        return_min_max=True,
    )

    vdi_img = alpha_composite(
        image,
        vdi,
        opacity=0.5,
        colormap='jet_r',
        alpha_image=(mask != 1)[..., None],
        alpha_mask=(mask != 1)[..., None],
    )

    return _figure_with_colorbar(vdi_img, vdi_min, vdi_max, 'jet_r', 'VDI')