from pprint import pprint

import cv2
import numpy as np
import torch
import albumentations as A
from albumentations import ImageOnlyTransform
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset

import com_image as ci
import com_plot as cp


class FixPatchBrightness(ImageOnlyTransform):
    """Pull a patch's perceived brightness towards a target value.

    Patches whose average perceived brightness already falls inside
    ``brightness_thresholds`` are returned unchanged; brighter patches are
    gamma-darkened, darker patches are linearly brightened.
    """

    def __init__(
        self,
        brightness_target=115,
        brightness_thresholds=(115, 130),
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        self.brightness_target = brightness_target
        self.brightness_thresholds = brightness_thresholds

    def apply(self, img, brightness_target=None, brightness_thresholds=None, **params):
        brightness_target = (
            self.brightness_target if brightness_target is None else brightness_target
        )
        brightness_thresholds = (
            self.brightness_thresholds
            if brightness_thresholds is None
            else brightness_thresholds
        )
        # Perceived-brightness estimate (weighted RMS of the channels),
        # assuming the image arrives in RGB channel order.
        r, g, b = cv2.split(img)
        avg_bright = np.sqrt(
            0.241 * np.power(r.astype(float), 2)
            + 0.691 * np.power(g.astype(float), 2)
            + 0.068 * np.power(b.astype(float), 2)
        ).mean()
        tmin, tmax = min(brightness_thresholds), max(brightness_thresholds)
        if tmin <= avg_bright <= tmax:
            # Already inside the accepted band: nothing to fix.
            return img
        if avg_bright > brightness_target:
            # Too bright: darken with gamma correction via a lookup table.
            inv_gamma = avg_bright / brightness_target
            table = np.array(
                [((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)]
            ).astype("uint8")
            return cv2.LUT(src=img, lut=table)
        # Too dark: brighten with a linear gain and offset that meet the
        # target halfway.
        return cv2.convertScaleAbs(
            src=img,
            alpha=(brightness_target + avg_bright) / (2 * avg_bright),
            beta=(brightness_target - avg_bright) / 2,
        )


def build_albumentations(
    image_size: int,
    gamma=(60, 180),
    brightness_limit=0.15,
    contrast_limit=0.25,
    crop=None,
    center_crop: int = -1,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    brightness_target=None,
    brightness_thresholds=None,
    affine_transforms=None,
):
    """Build named groups of transforms that get_augmentations composes.

    ``crop`` may be an int (random crop size) or a dict with ``value`` and
    ``p`` keys. ``affine_transforms`` maps the flags H/V/R/T to the
    probabilities of horizontal flip, vertical flip, 90-degree rotation and
    transpose respectively.
    """
    if affine_transforms is None:
        # Avoid a mutable default argument.
        affine_transforms = {"H": 0.3, "V": 0.3, "R": 0.3, "T": 0.3}
    albs_ = {"resize": [A.Resize(height=image_size, width=image_size, p=1)]}
    if brightness_target is not None and brightness_thresholds is not None:
        albs_ = albs_ | {
            "fix_brightness": [
                FixPatchBrightness(
                    brightness_target=brightness_target,
                    brightness_thresholds=brightness_thresholds,
                    p=1,
                )
            ]
        }
    if crop is not None:
        if isinstance(crop, int):
            albs_ = albs_ | {
                "crop_and_pad": [
                    A.RandomCrop(height=crop, width=crop, p=0.5),
                    A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
                ]
            }
        elif isinstance(crop, dict):
            crop_val = crop["value"]
            crop_p = crop["p"]
            albs_ = albs_ | {
                "crop_and_pad": [
                    A.PadIfNeeded(min_height=crop_val, min_width=crop_val, p=1),
                    A.RandomCrop(height=crop_val, width=crop_val, p=crop_p),
                    A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
                ]
            }
    if center_crop > -1:
        albs_ = albs_ | {
            "center_crop": [
                A.PadIfNeeded(min_height=center_crop, min_width=center_crop, p=1),
                A.CenterCrop(height=center_crop, width=center_crop, p=1),
            ]
        }
    affine = []
    for k, v in affine_transforms.items():
        if k == "H":
            affine.append(A.HorizontalFlip(p=v))
        elif k == "V":
            affine.append(A.VerticalFlip(p=v))
        elif k == "R":
            affine.append(A.RandomRotate90(p=v))
        elif k == "T":
            affine.append(A.Transpose(p=v))
    albs_ = albs_ | {"affine": affine}
    color = []
    if brightness_limit is not None and contrast_limit is not None:
        color.append(
            A.RandomBrightnessContrast(
                brightness_limit=brightness_limit,
                contrast_limit=contrast_limit,
                p=0.5,
            )
        )
    if gamma is not None:
        color.append(A.RandomGamma(gamma_limit=gamma, p=0.5))
    albs_ = albs_ | {"color": color}
    return albs_ | {
        "to_tensor": [A.Normalize(mean=mean, std=std, p=1), ToTensorV2()],
        "un_normalize": [
            # Inverse of A.Normalize above: maps a normalized tensor back to
            # a [0, 1] image for plotting.
            A.Normalize(
                mean=[-m / s for m, s in zip(mean, std)],
                std=[1.0 / s for s in std],
                always_apply=True,
                max_pixel_value=1.0,
            ),
        ],
    }
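
# Illustrative sketch, not part of the original pipeline: combining two of
# the transform groups returned by build_albumentations by hand. The grey
# test patch and the chosen group names are assumptions for demonstration.
def _demo_build_albumentations():
    groups = build_albumentations(image_size=224)
    # Deterministic groups only: resize, then normalize + tensor conversion.
    pipeline = A.Compose(groups["resize"] + groups["to_tensor"])
    patch = np.full((300, 300, 3), 128, dtype=np.uint8)  # synthetic grey patch
    tensor = pipeline(image=patch)["image"]
    print(tensor.shape)  # expected: torch.Size([3, 224, 224])
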
def get_augmentations(
    image_size: int = 224,
    gamma=(60, 180),
    brightness_limit=0.15,
    contrast_limit=0.25,
    crop=180,
    center_crop: int = -1,
    kinds=None,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    brightness_target=None,
    brightness_thresholds=None,
    affine_transforms=None,
):
    """Compose the transform groups named in ``kinds`` into one pipeline."""
    # Copy before mutating so the caller's list is left untouched.
    kinds = ["resize", "to_tensor"] if kinds is None else list(kinds)
    if "train" in kinds:
        # "train" is shorthand for the affine and color groups.
        idx = kinds.index("train")
        kinds[idx : idx + 1] = ["affine", "color"]
    td_ = build_albumentations(
        image_size=image_size,
        gamma=gamma,
        brightness_limit=brightness_limit,
        contrast_limit=contrast_limit,
        crop=crop,
        center_crop=center_crop,
        mean=mean,
        std=std,
        brightness_target=brightness_target,
        brightness_thresholds=brightness_thresholds,
        affine_transforms=affine_transforms,
    )
    augs = []
    for k in kinds:
        if k:
            augs += td_[k]
    return A.Compose(augs)


class MlcPatches(Dataset):
    """Minimal dataset that loads the patch images named in a dataframe."""

    def __init__(self, dataframe, transform, path_to_images) -> None:
        super().__init__()
        self.dataframe = dataframe
        self.transform = transform
        self.path_to_images = path_to_images

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, index):
        img = self.transform(image=self.get_image(index=index))["image"]
        # Constant placeholder label: this class is only used for previews.
        return {"image": img, "labels": torch.tensor([1])}

    def get_image(self, index):
        return ci.load_image(
            file_name=self.dataframe.file_name.iloc[index],
            path_to_images=self.path_to_images,
        )


def test_augmentations(
    df,
    image_size,
    path_to_images,
    columns=None,
    kinds=None,
    rows: int = 2,
    cols: int = 4,
    **aug_params,
):
    """Plot one source patch next to rows * cols augmented versions of it."""
    columns = [] if columns is None else columns
    kinds = ["resize", "to_tensor"] if kinds is None else kinds
    sample = df.sample(n=1)
    src_dataset = MlcPatches(
        dataframe=sample,
        transform=get_augmentations(
            image_size=image_size, kinds=["resize", "to_tensor"], **aug_params
        ),
        path_to_images=path_to_images,
    )
    test_dataset = MlcPatches(
        dataframe=sample,
        transform=get_augmentations(image_size=image_size, kinds=kinds, **aug_params),
        path_to_images=path_to_images,
    )
    pprint(sample[[c for c in ["file_name"] + columns if c in sample]])
    cp.tensor_image_to_grid(
        # Indexing test_dataset re-runs the random pipeline each time, so
        # every "augmented" tile is a different draw.
        images=[(src_dataset[0]["image"], "source")]
        + [(test_dataset[0]["image"], "augmented") for _ in range(rows * cols)],
        transform=get_augmentations(
            image_size=image_size, kinds=["un_normalize"], **aug_params
        ),
        row_count=rows,
        col_count=cols,
        figsize=(cols * 4, rows * 4),
    )
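
# Hedged usage sketch, not part of the original module. It assumes pandas is
# available plus the project-local com_image/com_plot helpers; the dataframe,
# the file name "patch_0001.png" and the "images/" directory below are
# hypothetical placeholders for demonstration.
if __name__ == "__main__":
    import pandas as pd

    df = pd.DataFrame({"file_name": ["patch_0001.png"]})  # hypothetical patch
    test_augmentations(
        df=df,
        image_size=224,
        path_to_images="images/",  # hypothetical directory
        kinds=["resize", "train", "to_tensor"],
        rows=2,
        cols=4,
    )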