from PIL import Image
from typing import Tuple

import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import torch
from torchvision.transforms import functional as tvf
from pathlib import Path


def sliced_mean(x, slice_size):
    """Mean over every `slice_size` x `slice_size` window, computed with prefix sums."""
    # Prefix sums along axis 0: cs_y[i] holds, for each column, the sum of rows 0..i-1.
    cs_y = np.cumsum(x, axis=0)
    cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
    slices_y = (cs_y[slice_size:] - cs_y[:-slice_size]) / slice_size
    # Repeat along axis 1 to obtain the mean over each square window.
    cs_xy = np.cumsum(slices_y, axis=1)
    cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
    slices_xy = (cs_xy[:, slice_size:] - cs_xy[:, :-slice_size]) / slice_size
    return slices_xy


def sliced_var(x, slice_size):
    """Variance over every `slice_size` x `slice_size` window, via E[x^2] - E[x]^2."""
    x = x.astype('float64')
    return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2


def calculate_local_variance(img, var_window):
    """Return a local variance map with the same size as the input image."""
    var = sliced_var(img, var_window)
    # Pad back to the input size (the sliding window shrinks the map by var_window - 1).
    left_pad = var_window // 2 - 1
    right_pad = var_window - 1 - left_pad
    var_padded = np.pad(var, pad_width=((left_pad, right_pad), (left_pad, right_pad)))
    return var_padded


def get_crop_batch(img: np.ndarray,
                   mask: np.ndarray,
                   crop_size=96,
                   crop_scales=np.geomspace(0.5, 2, 7),
                   samples_per_scale=32,
                   use_variance_threshold=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate a batch of cropped images from an input image and a corresponding mask,
    at various scales and rotations.

    Parameters
    ----------
    img : np.ndarray
        The input image from which crops are generated.
    mask : np.ndarray
        The binary mask indicating the region of interest in the image.
    crop_size : int, optional
        The size of the square crop.
    crop_scales : np.ndarray, optional
        An array of scale factors to apply to the crop size.
    samples_per_scale : int, optional
        Number of samples to generate per scale factor.
    use_variance_threshold : bool, optional
        Flag to use variance thresholding for selecting crop locations.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        A tuple containing the tensor of crops, their rotation angles, and scale factors.
""" # pad pad_size = int(np.ceil(0.5*crop_size*max(crop_scales)*(np.sqrt(2)-1))) img_padded = np.pad(img, pad_size) mask_padded = np.pad(mask, pad_size) # distance map distance_map_padded = ndimage.distance_transform_edt(mask_padded) # TODO: adjust scales and samples_per_scale if use_variance_threshold: variance_window = min(crop_size//2, min(img.shape)) variance_map_padded = np.pad(calculate_local_variance(img, variance_window), pad_size) variance_median = np.ma.median(np.ma.masked_where(distance_map_padded<0.5*variance_window, variance_map_padded)) variance_mask = variance_map_padded >= variance_median else: variance_mask = np.ones_like(mask_padded) # initilize output crops_granum = [] angles_granum = [] scales_granum = [] # loop over scales for scale in crop_scales: half_crop_size_scaled = int(np.floor(scale*0.5*crop_size)) # half of crop size after scaling crop_pad = int(np.ceil((np.sqrt(2) - 1)*half_crop_size_scaled)) # pad added in order to allow rotation half_crop_size_external = half_crop_size_scaled + crop_pad # size of "external crop" which will be rotated possible_indices = np.stack(np.where(variance_mask & (distance_map_padded >= 2*half_crop_size_scaled)), axis=1) if len(possible_indices) == 0: continue chosen_indices = np.random.choice(np.arange(len(possible_indices)), min(len(possible_indices), samples_per_scale), replace=False) crops = [ img_padded[y-half_crop_size_external:y+half_crop_size_external, x-half_crop_size_external:x+half_crop_size_external] for y, x in possible_indices[chosen_indices] ] # rotate rotation_angles = np.random.rand(len(crops))*180 - 90 crops = [ ndimage.rotate(crop, angle, reshape=False)[crop_pad:-crop_pad,crop_pad:-crop_pad] for crop, angle in zip(crops, rotation_angles) ] # add to output crops_granum.append(tvf.resize(torch.tensor(np.array(crops)), (crop_size,crop_size),antialias=True)) # resize crops to crop_size angles_granum.extend(rotation_angles.tolist()) scales_granum.extend([scale]*len(crops)) if len(angles_granum) == 0: return [], [], [] crops_granum = torch.concat(crops_granum) angles_granum = torch.tensor(angles_granum, dtype=torch.float) scales_granum = torch.tensor(scales_granum, dtype=torch.float) return crops_granum, angles_granum, scales_granum def get_crop_batch_from_path(img_path, mask_path=None, use_variance_threshold=False): """ Load an image and its mask from file paths and generate a batch of cropped images. Parameters ---------- img_path : str Path to the input image. mask_path : str, optional Path to the binary mask image. If None, assumes mask path by replacing image extension with '.npy'. use_variance_threshold : bool, optional Flag to use variance thresholding for selecting crop locations. Returns ------- Tuple[torch.Tensor, torch.Tensor, torch.Tensor] A tuple containing the tensor of crops, their rotation angles, and scale factors, obtained from the specified image path. """ if mask_path is None: mask_path = str(Path(img_path).with_suffix('.npy')) mask = np.load(mask_path) img = np.array(Image.open(img_path))[:,:,0] return get_crop_batch(img, mask, use_variance_threshold=use_variance_threshold)