Spaces:
Running
Running
import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from PIL import Image
from scipy.ndimage import grey_dilation
from tqdm import tqdm
from transformers import SegformerForSemanticSegmentation

from .viz_utils import alpha_composite
def read_raster(path, order='CHW'):
    """Load a raster from *path* and return it as a numpy array.

    `order` selects the axis layout of the result: 'CHW' (rasterio's
    native band-first layout, the default) or 'HWC' (channels last).
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    with rasterio.open(path) as src:
        data = src.read()  # rasterio always yields CHW
    return np.moveaxis(data, 0, -1) if order == 'HWC' else data
def write_raster(path, img, profile, order='CHW'):
    """Write a numpy array to a raster file using a rasterio *profile*.

    `order` declares the layout of *img*; rasterio expects band-first,
    so an 'HWC' input is transposed before writing.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    data = np.moveaxis(img, -1, 0) if order == 'HWC' else img
    with rasterio.open(path, 'w', **profile) as dst:
        dst.write(data)
def resize(img, shape=None, scaling_factor=1., order='CHW'):
    """Resize an image to an explicit (H, W) *shape* or by *scaling_factor*.

    Exactly one of the two may be given. Interpolation is bilinear.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    assert shape is None or scaling_factor == 1., "Got both shape and scaling_factor. Please provide only one of them"
    channels_first = order == 'CHW'
    if channels_first:
        img = np.moveaxis(img, 0, -1)  # CHW -> HWC for OpenCV
    if shape is not None:
        # cv2 wants (W, H), hence the reversal
        resized = cv2.resize(img, shape[::-1], interpolation=cv2.INTER_LINEAR)
    else:
        resized = cv2.resize(img, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_LINEAR)
    # cv2.resize drops a singleton channel axis (HW1 in -> HW out): restore it
    if resized.ndim == 2:
        resized = resized[..., None]
    return np.moveaxis(resized, -1, 0) if channels_first else resized
def minimum_needed_padding(img_size, patch_size: int, stride: int):
    """
    Compute the minimum padding that makes an image tileable into patches.

    Args:
        img_size (tuple): the (H, W) size of the image
        patch_size (int): the size of the square patches to extract
        stride (int): the stride used when extracting patches

    Returns:
        tuple: (pad_t, pad_b, pad_l, pad_r) padding values
    """
    size = np.asarray(img_size)
    # images smaller than a patch are padded up to one patch;
    # the % patch_size handles the degenerate size == (0, 0) case
    pad_if_small = (patch_size - size) % patch_size
    # otherwise pad so that (size - patch_size) is a multiple of the stride
    pad_if_large = (stride - (size - patch_size)) % stride
    total = np.where(size <= patch_size, pad_if_small, pad_if_large)
    pad_t, pad_l = total // 2
    pad_b = total[0] - pad_t
    pad_r = total[1] - pad_l
    return pad_t, pad_b, pad_l, pad_r
def pad(img, pad, order='CHW'):
    """Zero-pad an image by the given values (pad_t, pad_b, pad_l, pad_r).

    Accepts a numpy array or a CPU torch.Tensor; the return type matches
    the input type. Padding is constant zeros (mode='reflect' is a
    possible alternative).
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    pad_t, pad_b, pad_l, pad_r = pad
    if order == 'HWC':
        widths = ((pad_t, pad_b), (pad_l, pad_r), (0, 0))
    else:
        widths = ((0, 0), (pad_t, pad_b), (pad_l, pad_r))
    # np.pad converts a (CPU) tensor input to an ndarray via __array__
    padded_img = np.pad(img, widths, mode='constant', constant_values=0)
    if isinstance(img, torch.Tensor):
        # from_numpy wraps the buffer zero-copy, unlike torch.tensor()
        # which would allocate and copy a second time
        padded_img = torch.from_numpy(padded_img)
    return padded_img
def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True):
    """Extract square patches from an image on a regular stride grid.

    Args:
        img: the image, laid out according to *order*
        patch_size (int): side of the square patches
        stride (int): step between consecutive patches
        order (str): 'CHW' or 'HWC'
        only_return_idx (bool): if True, return only the patch indexes

    Returns:
        Patch indexes as (h_start, h_end, w_start, w_end) tuples, and
        the patch views themselves when only_return_idx is False.
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    if order == 'HWC':
        H, W = img.shape[:2]
    else:
        H, W = img.shape[1:]
    # (the previous implementation pre-computed an unused patch count here)
    patches = []
    patches_idx = []
    for top in range(0, H - patch_size + 1, stride):
        for left in range(0, W - patch_size + 1, stride):
            patches_idx.append((top, top + patch_size, left, left + patch_size))
            if not only_return_idx:
                if order == 'HWC':
                    patches.append(img[top:top + patch_size, left:left + patch_size, :])
                else:
                    patches.append(img[:, top:top + patch_size, left:left + patch_size])
    if only_return_idx:
        return patches_idx
    return patches, patches_idx
def segment_batch(batch, model):
    """Run *model* on a batch of patches and return per-pixel confidence scores.

    Output shape is (n_patches, 1, H, W), values in [0, 1] after sigmoid.
    """
    with torch.no_grad():
        out = model(batch)  # (n_patches, 1, H, W) logits
    # HuggingFace Segformer wraps low-resolution logits in an output object:
    # unwrap and bring them back to the input resolution
    if isinstance(model, SegformerForSemanticSegmentation):
        out = upsample(out.logits, size=batch.shape[-2:])
    return torch.sigmoid(out)  # logits -> confidence scores
def upsample(x, size):
    """Bilinearly resize a 4D (N, C, H, W) tensor to the given spatial *size*.

    NB: mode='bilinear' requires a 4D input; interpolate would raise on
    3D/5D tensors (the previous docstring claimed those were supported).
    """
    return torch.nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=False)
def merge_patches(patches, patches_idx, rotate=False, canvas_shape=None, order='CHW'):
    """Merge (possibly overlapping) patches back into a single image.

    Overlapping regions are averaged; pixels covered by no patch are 0.
    When rotate=True, patch i is assumed to have been rotated by
    (i % 4) * 90 degrees and is rotated back before merging.

    Args:
        patches: list of patch arrays, all in *order* layout
        patches_idx: list of (h_start, h_end, w_start, w_end) tuples
        rotate (bool): undo per-patch 90-degree rotations
        canvas_shape (tuple): (H, W) of the output; inferred from
            patches_idx when None
        order (str): 'CHW' or 'HWC'

    Returns:
        np.ndarray: the merged (float) image
    """
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    # the length check now also covers the rotate path (zip would
    # otherwise silently truncate on mismatched inputs)
    assert len(patches) == len(patches_idx), f"Got {len(patches)} patches and {len(patches_idx)} indexes"
    if rotate:
        axes_to_rotate = (0, 1) if order == 'HWC' else (1, 2)
        patches = [np.rot90(p, -i, axes=axes_to_rotate) for i, p in enumerate(patches)]
    # if canvas_shape is None, infer it from patches_idx
    if canvas_shape is None:
        patches_idx_zipped = list(zip(*patches_idx))
        canvas_H = max(patches_idx_zipped[1])
        canvas_W = max(patches_idx_zipped[3])
    else:
        canvas_H, canvas_W = canvas_shape
    # initialize accumulation canvas and per-pixel coverage counter
    dtype = patches[0].dtype
    if order == 'HWC':
        canvas_C = patches[0].shape[-1]
        canvas = np.zeros((canvas_H, canvas_W, canvas_C), dtype=dtype)  # HWC
        n_overlapping_patches = np.zeros((canvas_H, canvas_W, 1))
    else:
        canvas_C = patches[0].shape[0]
        canvas = np.zeros((canvas_C, canvas_H, canvas_W), dtype=dtype)  # CHW
        n_overlapping_patches = np.zeros((1, canvas_H, canvas_W))
    # accumulate patches and coverage counts
    for p, (t, b, l, r) in zip(patches, patches_idx):
        if order == 'HWC':
            canvas[t:b, l:r, :] += p
            n_overlapping_patches[t:b, l:r, 0] += 1
        else:
            canvas[:, t:b, l:r] += p
            n_overlapping_patches[0, t:b, l:r] += 1
    # average overlaps; `out=` is mandatory here: with only `where=`,
    # np.divide leaves the non-covered entries uninitialized (garbage)
    canvas = np.divide(
        canvas, n_overlapping_patches,
        out=np.zeros(canvas.shape, dtype=np.result_type(canvas, n_overlapping_patches)),
        where=(n_overlapping_patches != 0),
    )
    return canvas
def segment(img, model, patch_size=512, stride=256, scaling_factor=1., rotate=False, device=None, batch_size=16, verbose=False):
    """Segment an RGB(A) image with a patch-based segmentation model.

    The image is optionally rescaled, padded, split into overlapping
    patches, segmented in batches, and the patch predictions are merged
    (averaged) back into a full-resolution probability map.

    Args:
        img (np.ndarray): CHW uint8 image with 3 (RGB) or 4 (RGBA) channels
        model: a torch segmentation model returning (N, 1, h, w) logits
        patch_size (int): side of the square patches
        stride (int): stride between patches (overlap = patch_size - stride)
        scaling_factor (float): optional rescaling applied before inference
        rotate (bool): test-time augmentation with 90-degree rotations
        device: torch device for inference
        batch_size (int): number of patches per forward pass
        verbose (bool): show a progress bar

    Returns:
        np.ndarray: (H, W) probability map in [0, 1]; pixels transparent
        in the input alpha channel are set to -1.
    """
    # input checks
    assert isinstance(img, np.ndarray), f"Input must be a numpy array. Got {type(img)}"
    assert img.shape[0] in [3,4], f"Input image must be formatted as CHW, with C = 3,4. Got a shape of {img.shape}"
    assert img.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {img.dtype}"
    # prepare model for evaluation
    model = model.to(device)
    model.eval()
    # split off the alpha channel (or synthesize a fully-opaque one)
    original_shape = img.shape
    if img.shape[0] == 3:
        alpha = np.full(original_shape[1:], 255, dtype=np.uint8)
    else:
        img, alpha = img[:3], img[3]
    # resize and pad so the patch grid tiles the image exactly
    img = resize(img, scaling_factor=scaling_factor)
    pad_t, pad_b, pad_l, pad_r = minimum_needed_padding(img.shape[1:], patch_size, stride)
    padded_img = pad(img, pad=(pad_t, pad_b, pad_l, pad_r))
    padded_shape = padded_img.shape
    patches_idx = extract_patches(padded_img, patch_size=patch_size, stride=stride)
    # ImageNet normalization constants (hoisted out of the loop)
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    ### segment
    masks = []
    masks_idx = []
    batch = []
    for i, p_idx in enumerate(tqdm(patches_idx, disable=not verbose, desc="Predicting...", total=len(patches_idx))):
        t, b, l, r = p_idx
        patch = padded_img[:, t:b, l:r]
        # consider patch only if it is valid (i.e. not all black or all white)
        if np.any(patch != 0) and np.any(patch != 255):
            # uint8 -> float32 in [0,1], then ImageNet-normalize
            patch = torch.tensor(patch).float() / 255.
            patch = (patch - imagenet_mean) / imagenet_std
            batch.append(patch)
            masks_idx.append(p_idx)
            # (optional) add the three rotated versions of the patch.
            # NB: each rotation is applied to the ORIGINAL patch, giving
            # rotations 1, 2, 3 as merge_patches expects; the previous
            # cumulative rotation produced 1, 3, 2 and the rotated masks
            # were merged back mis-oriented.
            if rotate:
                for rot in range(1, 4):
                    batch.append(torch.rot90(patch, rot, dims=[1, 2]))
                    masks_idx.append(p_idx)
        # flush a full (or final) batch; skip empty batches, which would
        # make torch.stack raise
        if batch and (len(batch) >= batch_size or i == len(patches_idx) - 1):
            stacked = torch.stack(batch).to(device)
            out = segment_batch(stacked, model)
            masks.append(out.cpu().numpy())
            batch = []
    if masks:
        masks = np.concatenate(masks)  # (n_patches, 1, H, W)
        mask = merge_patches(masks, masks_idx, rotate=rotate, canvas_shape=padded_shape[1:])  # (1, H, W)
    else:
        # no valid patch at all: all-zero probability map
        mask = np.zeros((1, *padded_shape[1:]), dtype=np.float32)
    # undo padding and restore the original resolution
    mask = mask[:, pad_t:padded_shape[1]-pad_b, pad_l:padded_shape[2]-pad_r]
    mask = resize(mask, shape=original_shape[1:])
    # mark fully-transparent pixels with -1
    mask = np.where(alpha == 0, -1, mask)
    return mask.squeeze()
def sliding_window_avg_pooling(img, window, granularity, alpha=None, min_nonblank_pixels=0., normalize=False, return_min_max=False, verbose=False):
    """Smooth a single-channel HWC image with a sliding-window average.

    Each window of side *window*, moved with step *granularity*, is filled
    with the alpha-weighted mean of the pixels it covers; overlapping
    windows are then averaged together. Windows whose alpha coverage does
    not exceed min_nonblank_pixels * window**2 are discarded.

    Returns the pooled image, plus (min, max) when both normalize and
    return_min_max are set.
    """
    assert isinstance(img, np.ndarray), f'Input image must be a numpy array. Got {type(img)}'
    assert img.shape[2] == 1, f'Input image must be formatted as HWC, with C = 1. Got a shape of {img.shape}'
    if alpha is None:
        alpha = np.ones_like(img)
    else:
        # validate the alpha channel and cast it to float32 in [0, 1]
        assert isinstance(alpha, np.ndarray), f'Alpha channel must be a numpy array. Got {type(alpha)}'
        assert alpha.shape[2] == 1, f'Alpha channel must be formatted as HWC, with C = 1. Got a shape of {alpha.shape}'
        assert img.shape == alpha.shape, f'The shape of input image {img.shape} and alpha channel {alpha.shape} do not match'
        if alpha.dtype == np.uint8:
            alpha = (alpha / 255).astype(np.float32)
        elif alpha.dtype == bool:
            alpha = alpha.astype(np.float32)
    # window positions and contents, for both the image and its alpha
    patches, patches_idx = extract_patches(img, patch_size=window, stride=granularity, order='HWC', only_return_idx=False)
    patches_alpha, _ = extract_patches(alpha, patch_size=window, stride=granularity, order='HWC', only_return_idx=False)
    # drop windows with insufficient (alpha-weighted) coverage; the strict
    # > keeps the later division by p_a.sum() safe even at threshold 0
    threshold = min_nonblank_pixels * window**2
    keep = [
        i for i, p_a in tqdm(enumerate(patches_alpha), total=len(patches), disable=not verbose)
        if p_a.sum() > threshold
    ]
    patches = [patches[i] for i in keep]
    patches_idx = [patches_idx[i] for i in keep]
    patches_alpha = [patches_alpha[i] for i in keep]
    # fill every kept window with its alpha-weighted mean value
    patches_density = [np.full_like(p_a, (p * p_a).sum() / p_a.sum()) for p, p_a in zip(patches, patches_alpha)]
    # average the overlapping windows into a single image, then re-mask
    pooled_img = merge_patches(patches_density, patches_idx, canvas_shape=img.shape[:2], order='HWC')
    pooled_img = pooled_img * alpha
    if normalize:
        # [0,1]-normalize
        pooled_img_min = pooled_img.min()
        pooled_img_max = pooled_img.max()
        pooled_img = (pooled_img - pooled_img_min) / (pooled_img_max - pooled_img_min)
        if return_min_max:
            return pooled_img, pooled_img_min, pooled_img_max
    return pooled_img
def _vndvi_figure(heatmap_img, vmin, vmax):
    """Render *heatmap_img* with a vertical 'vNDVI' RdYlGn colorbar; return the figure."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.15)
    ax.imshow(heatmap_img)
    fig.colorbar(
        mappable = mpl.cm.ScalarMappable(
            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax),
            cmap = 'RdYlGn'),
        cax = cax,
        orientation = 'vertical',
        label = 'vNDVI',
        shrink = 1)
    return fig

def compute_vndvi(image, mask, dilate_rows=True, window_size=360):
    """Compute visible-band NDVI (vNDVI) heatmap figures for rows and interrows.

    Args:
        image (np.ndarray): CHW uint8 RGB image
        mask (np.ndarray): HW uint8 mask; 255 = row, 0 = interrow, 1 = nodata
        dilate_rows (bool): visually thicken the row heatmap.
            NB: this parameter is now honored — it was previously
            overwritten with True inside the function.
        window_size (int): pooling window reference size in pixels

    Returns:
        tuple: (fig_rows, fig_interrows) matplotlib figures with colorbars
    """
    assert image.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {image.dtype}"
    assert mask.dtype == np.uint8, f"Input mask must be a numpy array with dtype np.uint8. Got {mask.dtype}"
    # CHW -> HWC
    image = image.transpose(1,2,0)
    # extract channels as float32 in [0,1]
    _image = image.astype(np.float32) / 255
    R, G, B = _image[:,:,0], _image[:,:,1], _image[:,:,2]
    # avoid division by 0 caused by the negative powers: replace 0 with 1
    R = np.where(R == 0, 1, R)
    B = np.where(B == 0, 1, B)
    # vNDVI formula (Costa et al.)
    vndvi = 0.5268 * (R**(-0.1294) * G**(0.3389) * B**(-0.3118))
    vndvi = np.clip(vndvi, 0, 1)
    # 10th/90th percentiles over all valid pixels
    # (mask is 1 for nodata, 0 or 255 for valid pixels)
    vndvi_perc10, vndvi_perc90 = np.percentile(vndvi[mask != 1], [10,90])
    # clip values between the two percentiles
    vndvi_clipped = np.clip(vndvi, vndvi_perc10, vndvi_perc90)
    # sliding-window smoothing, restricted to the rows ...
    vndvi_rows_clipped_pooled = sliding_window_avg_pooling(
        np.where(mask == 255, vndvi_clipped, 0)[...,None],
        window = int(window_size / 4),
        granularity = 10,
        alpha = (mask == 255)[...,None],
        min_nonblank_pixels = 0.0,
    )
    # ... and to the interrows
    vndvi_interrows_clipped_pooled = sliding_window_avg_pooling(
        np.where(mask == 0, vndvi_clipped, 0)[...,None],
        window = int(window_size / 4),
        granularity = 10,
        alpha = (mask == 0)[...,None],
        min_nonblank_pixels = 0.0,
    )
    # normalize overlays with the percentiles so the colormap spans
    # [perc10, perc90]; the dilated variants exist only when requested
    if dilate_rows:
        dil_factor = int(window_size / 60)
        mask_rows_dilated = grey_dilation(mask == 255, size=(dil_factor,dil_factor))
        vndvi_rows_dilated = grey_dilation(vndvi_rows_clipped_pooled, size=(dil_factor,dil_factor,1))
        rows_overlay = (vndvi_rows_dilated - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
        rows_alpha_mask = mask_rows_dilated[...,None]
    else:
        rows_overlay = (vndvi_rows_clipped_pooled - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
        rows_alpha_mask = (mask == 255)[...,None]
    interrows_overlay = (vndvi_interrows_clipped_pooled - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
    # composite the overlays on the image
    vndvi_rows_img = alpha_composite(
        image,
        rows_overlay,
        opacity = 1.0,
        colormap = 'RdYlGn',
        alpha_image = np.zeros_like(image[:,:,[0]]),
        alpha_mask = rows_alpha_mask,
    )
    vndvi_interrows_img = alpha_composite(
        image,
        interrows_overlay,
        opacity = 1.0,
        colormap = 'RdYlGn',
        alpha_image = np.zeros_like(image[:,:,[0]]),
        alpha_mask = (mask == 0)[...,None],
    )
    # build the two figures with colorbars
    fig_rows = _vndvi_figure(vndvi_rows_img, vndvi_perc10, vndvi_perc90)
    fig_interrows = _vndvi_figure(vndvi_interrows_img, vndvi_perc10, vndvi_perc90)
    return fig_rows, fig_interrows
def compute_vdi(image, mask, window_size=360):
    """Compute a vegetation density index (VDI) figure for a vineyard image.

    Args:
        image (np.ndarray): CHW uint8 RGB image
        mask (np.ndarray): HW uint8 mask; 255 = row, 0 = interrow, 1 = nodata
        window_size (int): side of the sliding pooling window in pixels

    Returns:
        matplotlib figure showing the VDI heatmap with a colorbar.
    """
    assert image.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {image.dtype}"
    assert mask.dtype == np.uint8, f"Input mask must be a numpy array with dtype np.uint8. Got {mask.dtype}"
    # CHW -> HWC
    image = image.transpose(1,2,0)
    # density of row pixels inside each sliding window, normalized to [0,1]
    # (mask is 1 for nodata, 0 or 255 for valid pixels)
    vdi, vdi_min, vdi_max = sliding_window_avg_pooling(
        (mask == 255)[...,None],
        window = window_size,
        granularity = 10,
        alpha = (mask != 1)[...,None],
        min_nonblank_pixels = 0.9,
        normalize=True,
        return_min_max=True
    )
    # overlay the heatmap on the image
    vdi_img = alpha_composite(
        image,
        vdi,
        opacity = 0.5,
        colormap = 'jet_r',
        alpha_image = (mask != 1)[...,None],
        alpha_mask = (mask != 1)[...,None],
    )
    # figure with a vertical colorbar on the right
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.15)
    ax.imshow(vdi_img)
    scalar_map = mpl.cm.ScalarMappable(
        norm = mpl.colors.Normalize(vmin = vdi_min, vmax = vdi_max),
        cmap = 'jet_r')
    fig.colorbar(
        mappable = scalar_map,
        cax = cax,
        orientation = 'vertical',
        label = 'VDI',
        shrink = 1)
    return fig