| 
							 | 
						from pprint import pprint | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import numpy as np | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import cv2 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import albumentations as A | 
					
					
						
						| 
							 | 
						from albumentations.pytorch import ToTensorV2 | 
					
					
						
						| 
							 | 
						from albumentations import ImageOnlyTransform | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						from torch.utils.data import Dataset | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import com_image as ci | 
					
					
						
						| 
							 | 
						import com_plot as cp | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class FixPatchBrightness(ImageOnlyTransform):
						    """Pull a patch's average perceived brightness towards a target value.

						    The average brightness is the mean over all pixels of
						    ``sqrt(0.241 * r**2 + 0.691 * g**2 + 0.068 * b**2)``.  Images whose
						    average already lies inside ``brightness_thresholds`` are returned
						    unchanged.  Brighter-than-target images are darkened with a gamma
						    lookup table; the remaining ones are corrected with a linear
						    scale/offset via ``cv2.convertScaleAbs``.

						    Args:
						        brightness_target: Average brightness the correction aims for.
						        brightness_thresholds: Pair of bounds (any order); images whose
						            average brightness lies within them are left untouched.
						        always_apply: Forwarded to ``ImageOnlyTransform``.
						        p: Probability of applying the transform, forwarded to
						            ``ImageOnlyTransform``.
						    """

						    def __init__(
						        self,
						        brightness_target=115,
						        brightness_thresholds=(115, 130),
						        always_apply: bool = False,
						        p: float = 0.5,
						    ):
						        # Keyword arguments keep this call independent of the positional
						        # order of the base-class constructor parameters.
						        super().__init__(always_apply=always_apply, p=p)
						        self.brightness_target = brightness_target
						        self.brightness_thresholds = brightness_thresholds

						    def apply(self, img, brightness_target=None, brightness_thresholds=None, **params):
						        """Return ``img`` with its average brightness moved towards the target.

						        Args:
						            img: Image to correct.  It is split into exactly three channels
						                named r, g, b -- assumes a 3-channel image in RGB order,
						                TODO confirm against the image loader.
						            brightness_target: Per-call override of ``self.brightness_target``.
						            brightness_thresholds: Per-call override of
						                ``self.brightness_thresholds``.
						            **params: Extra parameters; ignored.

						        Returns:
						            The input image itself when no correction is needed, otherwise a
						            corrected copy.
						        """
						        if brightness_target is None:
						            brightness_target = self.brightness_target
						        if brightness_thresholds is None:
						            brightness_thresholds = self.brightness_thresholds

						        r, g, b = cv2.split(img)
						        avg_bright = np.sqrt(
						            0.241 * np.power(r.astype(float), 2)
						            + 0.691 * np.power(g.astype(float), 2)
						            + 0.068 * np.power(b.astype(float), 2)
						        ).mean()

						        tmin, tmax = min(*brightness_thresholds), max(*brightness_thresholds)

						        # Already inside the accepted band: nothing to do.
						        if not (avg_bright < tmin or avg_bright > tmax):
						            return img

						        if avg_bright > brightness_target:
						            gamma = brightness_target / avg_bright
						            if gamma == 1:
						                return img
						            # Exponent above 1 on the [0, 1] scale darkens the image.
						            inv_gamma = 1.0 / gamma
						            levels = np.arange(0, 256) / 255.0
						            table = (np.power(levels, inv_gamma) * 255).astype("uint8")
						            return cv2.LUT(src=img, lut=table)

						        # A completely black image has avg_bright == 0, which would make the
						        # scale factor a division by zero.  Every pixel is 0 in that case, so
						        # the scale is irrelevant and only the offset matters.
						        if avg_bright > 0:
						            alpha = (brightness_target + avg_bright) / (2 * avg_bright)
						        else:
						            alpha = 1.0
						        return cv2.convertScaleAbs(
						            src=img,
						            alpha=alpha,
						            beta=(brightness_target - avg_bright) / 2,
						        )
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def build_albumentations(
						    image_size: int,
						    gamma=(60, 180),
						    brightness_limit=0.15,
						    contrast_limit=0.25,
						    crop=None,
						    center_crop: int = -1,
						    mean=(0.485, 0.456, 0.406),
						    std=(0.229, 0.224, 0.225),
						    brightness_target=None,
						    brightness_thresholds=None,
						    affine_transforms={"H": 0.3, "V": 0.3, "R": 0.3, "T": 0.3},
						):
						    """Build the named groups of albumentations transforms.

						    Args:
						        image_size: Side length used for the resize and the final padding.
						        gamma: ``gamma_limit`` for ``A.RandomGamma``; ``None`` disables it.
						        brightness_limit: Limit for ``A.RandomBrightnessContrast``.
						        contrast_limit: Limit for ``A.RandomBrightnessContrast``.  The
						            transform is added only when both limits are not ``None``.
						        crop: ``None`` for no random crop, an ``int`` crop side (applied with
						            probability 0.5), or a dict with keys ``"value"`` and ``"p"``.
						        center_crop: Side of the centre crop; values of -1 or less disable it.
						        mean: Per-channel mean for normalisation.
						        std: Per-channel standard deviation for normalisation.
						        brightness_target: Target for ``FixPatchBrightness``.
						        brightness_thresholds: Thresholds for ``FixPatchBrightness``.  The
						            transform is added only when both brightness values are given.
						        affine_transforms: Mapping from ``"H"``, ``"V"``, ``"R"``, ``"T"``
						            (horizontal flip, vertical flip, 90-degree rotation, transpose)
						            to the probability of that transform.  Other keys are ignored.
						            The default dict is only read, never modified.

						    Returns:
						        Dict mapping a group name to a list of transforms.  ``"resize"``,
						        ``"affine"``, ``"color"``, ``"to_tensor"`` and ``"un_normalize"``
						        are always present; ``"fix_brightness"``, ``"crop_and_pad"`` and
						        ``"center_crop"`` only when configured.
						    """
						    groups = {"resize": [A.Resize(height=image_size, width=image_size, p=1)]}

						    if not (brightness_target is None or brightness_thresholds is None):
						        groups["fix_brightness"] = [
						            FixPatchBrightness(
						                brightness_target=brightness_target,
						                brightness_thresholds=brightness_thresholds,
						                p=1,
						            )
						        ]

						    # ``None`` matches neither isinstance test, so no separate None check
						    # is required here.
						    if isinstance(crop, int):
						        groups["crop_and_pad"] = [
						            A.RandomCrop(height=crop, width=crop, p=0.5),
						            A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
						        ]
						    elif isinstance(crop, dict):
						        crop_side = crop["value"]
						        crop_prob = crop["p"]
						        groups["crop_and_pad"] = [
						            # Pad first so the crop window always fits inside the image.
						            A.PadIfNeeded(min_height=crop_side, min_width=crop_side, p=1),
						            A.RandomCrop(height=crop_side, width=crop_side, p=crop_prob),
						            A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
						        ]

						    if center_crop > -1:
						        groups["center_crop"] = [
						            A.PadIfNeeded(min_height=center_crop, min_width=center_crop, p=1),
						            A.CenterCrop(height=center_crop, width=center_crop, p=1),
						        ]

						    # One-letter code -> transform class; the order of the result follows
						    # the iteration order of ``affine_transforms``.
						    affine_by_code = {
						        "H": A.HorizontalFlip,
						        "V": A.VerticalFlip,
						        "R": A.RandomRotate90,
						        "T": A.Transpose,
						    }
						    groups["affine"] = [
						        affine_by_code[code](p=prob)
						        for code, prob in affine_transforms.items()
						        if code in affine_by_code
						    ]

						    color_ops = []
						    if not (brightness_limit is None or contrast_limit is None):
						        color_ops.append(
						            A.RandomBrightnessContrast(
						                brightness_limit=brightness_limit,
						                contrast_limit=contrast_limit,
						                p=0.5,
						            )
						        )
						    if gamma is not None:
						        color_ops.append(A.RandomGamma(gamma_limit=gamma, p=0.5))
						    groups["color"] = color_ops

						    groups["to_tensor"] = [A.Normalize(mean=mean, std=std, p=1), ToTensorV2()]
						    # Normalising with mean -m/s and std 1/s at max_pixel_value 1 maps x to
						    # x * s + m, i.e. it undoes the normalisation of "to_tensor".
						    inverse_mean = [-m / s for m, s in zip(mean, std)]
						    inverse_std = [1.0 / s for s in std]
						    groups["un_normalize"] = [
						        A.Normalize(
						            mean=inverse_mean,
						            std=inverse_std,
						            always_apply=True,
						            max_pixel_value=1.0,
						        ),
						    ]
						    return groups
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_augmentations(
						    image_size: int = 224,
						    gamma=(60, 180),
						    brightness_limit=0.15,
						    contrast_limit=0.25,
						    crop=180,
						    center_crop: int = -1,
						    kinds: list = ["resize", "to_tensor"],
						    mean=(0.485, 0.456, 0.406),
						    std=(0.229, 0.224, 0.225),
						    brightness_target=None,
						    brightness_thresholds=None,
						    affine_transforms={"H": 0.3, "V": 0.3, "R": 0.3, "T": 0.3},
						):
						    """Compose the transform groups named in ``kinds`` into one pipeline.

						    Args:
						        image_size: Forwarded to ``build_albumentations``.
						        gamma: Forwarded to ``build_albumentations``.
						        brightness_limit: Forwarded to ``build_albumentations``.
						        contrast_limit: Forwarded to ``build_albumentations``.
						        crop: Forwarded to ``build_albumentations``.
						        center_crop: Forwarded to ``build_albumentations``.
						        kinds: Ordered group names to chain.  ``"train"`` is shorthand for
						            ``"affine"`` followed by ``"color"``; falsy entries are skipped.
						            The sequence is only read, never modified.
						        mean: Forwarded to ``build_albumentations``.
						        std: Forwarded to ``build_albumentations``.
						        brightness_target: Forwarded to ``build_albumentations``.
						        brightness_thresholds: Forwarded to ``build_albumentations``.
						        affine_transforms: Forwarded to ``build_albumentations``.

						    Returns:
						        An ``A.Compose`` holding the transforms of every requested group, in
						        the order given by ``kinds``.

						    Raises:
						        KeyError: If a requested group was not built, e.g. an unknown name
						            or an optional group whose parameters were not supplied.
						    """
						    # Expand "train" into a private list so the caller's sequence (and the
						    # shared default list) is never mutated; this also accepts tuples.
						    stages = []
						    for kind in kinds:
						        if kind == "train":
						            stages.extend(("affine", "color"))
						        else:
						            stages.append(kind)

						    groups = build_albumentations(
						        image_size=image_size,
						        gamma=gamma,
						        brightness_limit=brightness_limit,
						        contrast_limit=contrast_limit,
						        crop=crop,
						        center_crop=center_crop,
						        mean=mean,
						        std=std,
						        brightness_target=brightness_target,
						        brightness_thresholds=brightness_thresholds,
						        affine_transforms=affine_transforms,
						    )

						    pipeline = []
						    for stage in stages:
						        if stage:
						            pipeline.extend(groups[stage])
						    return A.Compose(pipeline)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class MlcPatches(Dataset):
						    """Dataset that serves transformed images listed in a dataframe.

						    Args:
						        dataframe: Table with a ``file_name`` column, one row per image.
						        transform: Callable invoked as ``transform(image=...)`` that returns
						            a mapping with an ``"image"`` entry.
						        path_to_images: Location passed to ``ci.load_image`` together with
						            the file name.
						    """

						    def __init__(self, dataframe, transform, path_to_images) -> None:
						        super().__init__()
						        self.dataframe = dataframe
						        self.transform = transform
						        self.path_to_images = path_to_images

						    def __len__(self):
						        """Return the number of rows in the dataframe."""
						        return self.dataframe.shape[0]

						    def __getitem__(self, index):
						        """Return the transformed image at ``index`` with a constant label.

						        The transform runs on every access, so random augmentations can
						        yield a different image each time the same index is requested.
						        """
						        img = self.transform(image=self.get_image(index=index))["image"]
						        return {"image": img, "labels": torch.tensor([1])}

						    def get_image(self, index):
						        """Load the untransformed image for the row at position ``index``."""
						        return ci.load_image(
						            # Positional lookup; avoids converting the whole column to a
						            # list on every access.
						            file_name=self.dataframe.file_name.iloc[index],
						            path_to_images=self.path_to_images,
						        )
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def test_augmentations(
						    df,
						    image_size,
						    path_to_images,
						    columns: list = [],
						    kinds: list = ["resize", "to_tensor"],
						    rows: int = 2,
						    cols: int = 4,
						    **aug_params,
						):
						    """Show one random sample next to several augmented versions of it.

						    Args:
						        df: Dataframe with a ``file_name`` column to sample one row from.
						        image_size: Image side length passed to ``get_augmentations``.
						        path_to_images: Location of the image files.
						        columns: Extra dataframe columns to print alongside ``file_name``;
						            names missing from the dataframe are skipped.
						        kinds: Transform groups used for the augmented images.
						        rows: Passed to the grid as ``row_count``.
						        cols: Passed to the grid as ``col_count``.
						        **aug_params: Extra keyword arguments for ``get_augmentations``.
						    """
						    sample = df.sample(n=1)

						    # Reference pipeline: resize and tensor conversion only.
						    src_dataset = MlcPatches(
						        dataframe=sample,
						        transform=get_augmentations(
						            image_size=image_size, kinds=["resize", "to_tensor"], **aug_params
						        ),
						        path_to_images=path_to_images,
						    )

						    test_dataset = MlcPatches(
						        dataframe=sample,
						        # Hand over a copy so the caller's ``kinds`` list can never be
						        # modified by the pipeline builder.
						        transform=get_augmentations(
						            image_size=image_size, kinds=list(kinds), **aug_params
						        ),
						        path_to_images=path_to_images,
						    )

						    shown_columns = [c for c in ["file_name", *columns] if c in sample]
						    pprint(sample[shown_columns])

						    # Each access to ``test_dataset[0]`` runs the transform again.
						    augmented = [
						        (test_dataset[0]["image"], "augmented") for _ in range(rows * cols)
						    ]
						    # NOTE(review): this passes 1 + rows * cols images for a rows x cols
						    # grid -- confirm how cp.tensor_image_to_grid handles the extra one.
						    cp.tensor_image_to_grid(
						        images=[(src_dataset[0]["image"], "source")] + augmented,
						        transform=get_augmentations(
						            image_size=image_size, kinds=["un_normalize"], **aug_params
						        ),
						        row_count=rows,
						        col_count=cols,
						        figsize=(cols * 4, rows * 4),
						    )
					
					
						
						| 
							 | 
						
 |