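"""Albumentations-based augmentation pipelines for image patches.

The helpers below build named groups of transforms (resize, crop, affine,
color, normalization), expose them through a small PyTorch ``Dataset``,
and provide a grid-plot helper for visually checking augmentations.
"""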
from pprint import pprint

import numpy as np
import cv2

import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import torch
from torch.utils.data import Dataset

import com_image as ci
import com_plot as cp


class FixPatchBrightness(ImageOnlyTransform):
    """Pull a patch's perceived brightness towards a target value.

    Perceived brightness is estimated as
    sqrt(0.241 * R^2 + 0.691 * G^2 + 0.068 * B^2), averaged over the patch.
    Patches whose average falls outside ``brightness_thresholds`` are
    corrected: too-bright patches via a gamma lookup table, too-dark
    patches via a linear scale-and-shift.
    """

    def __init__(
        self,
        brightness_target=115,
        brightness_thresholds=(115, 130),
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        self.brightness_target = brightness_target
        self.brightness_thresholds = brightness_thresholds

    def apply(self, img, brightness_target=None, brightness_thresholds=None, **params):
        brightness_target = (
            self.brightness_target if brightness_target is None else brightness_target
        )
        brightness_thresholds = (
            self.brightness_thresholds
            if brightness_thresholds is None
            else brightness_thresholds
        )

        # Albumentations passes images as RGB arrays, so channel 0 is red.
        r, g, b = cv2.split(img)
        avg_bright = np.sqrt(
            0.241 * np.power(r.astype(float), 2)
            + 0.691 * np.power(g.astype(float), 2)
            + 0.068 * np.power(b.astype(float), 2)
        ).mean()

        tmin, tmax = min(brightness_thresholds), max(brightness_thresholds)

        if tmin <= avg_bright <= tmax:
            return img

        if avg_bright > brightness_target:
            # Too bright: darken with a gamma lookup table (gamma < 1 here,
            # so inv_gamma > 1 and the LUT maps values downwards).
            gamma = brightness_target / avg_bright
            inv_gamma = 1.0 / gamma
            table = np.array(
                [((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)]
            ).astype("uint8")
            return cv2.LUT(src=img, lut=table)

        # Too dark: brighten linearly, clipping to the uint8 range.
        return cv2.convertScaleAbs(
            src=img,
            alpha=(brightness_target + avg_bright) / (2 * avg_bright),
            beta=(brightness_target - avg_bright) / 2,
        )

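# A minimal usage sketch (the patch values here are illustrative; the
# transform expects a uint8 RGB array):
#
#     rng = np.random.default_rng(0)
#     dark_patch = rng.integers(0, 40, size=(64, 64, 3), dtype=np.uint8)
#     fix = FixPatchBrightness(brightness_target=115,
#                              brightness_thresholds=(115, 130), p=1)
#     fixed = fix(image=dark_patch)["image"]  # brightness pulled up towards 115
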
def build_albumentations(
    image_size: int,
    gamma=(60, 180),
    brightness_limit=0.15,
    contrast_limit=0.25,
    crop=None,
    center_crop: int = -1,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    brightness_target=None,
    brightness_thresholds=None,
    affine_transforms={"H": 0.3, "V": 0.3, "R": 0.3, "T": 0.3},
):
    albs_ = {"resize": [A.Resize(height=image_size, width=image_size, p=1)]}

    if brightness_target is not None and brightness_thresholds is not None:
        albs_ = albs_ | {
            "fix_brightness": [
                FixPatchBrightness(
                    brightness_target=brightness_target,
                    brightness_thresholds=brightness_thresholds,
                    p=1,
                )
            ]
        }

    if crop is not None:
        if isinstance(crop, int):
            albs_ = albs_ | {
                "crop_and_pad": [
                    A.RandomCrop(height=crop, width=crop, p=0.5),
                    A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
                ]
            }
        elif isinstance(crop, dict):
            crop_val = crop["value"]
            crop_p = crop["p"]
            albs_ = albs_ | {
                "crop_and_pad": [
                    A.PadIfNeeded(min_height=crop_val, min_width=crop_val, p=1),
                    A.RandomCrop(height=crop_val, width=crop_val, p=crop_p),
                    A.PadIfNeeded(min_height=image_size, min_width=image_size, p=1),
                ]
            }

    if center_crop > -1:
        albs_ = albs_ | {
            "center_crop": [
                A.PadIfNeeded(min_height=center_crop, min_width=center_crop, p=1),
                A.CenterCrop(height=center_crop, width=center_crop, p=1),
            ]
        }

    affine = []
    for k, v in affine_transforms.items():
        if k == "H":
            affine.append(A.HorizontalFlip(p=v))
        elif k == "V":
            affine.append(A.VerticalFlip(p=v))
        elif k == "R":
            affine.append(A.RandomRotate90(p=v))
        elif k == "T":
            affine.append(A.Transpose(p=v))
    albs_ = albs_ | {"affine": affine}

    color = []
    if brightness_limit is not None and contrast_limit is not None:
        color.append(
            A.RandomBrightnessContrast(
                brightness_limit=brightness_limit,
                contrast_limit=contrast_limit,
                p=0.5,
            )
        )
    if gamma is not None:
        color.append(A.RandomGamma(gamma_limit=gamma, p=0.5))

    albs_ = albs_ | {"color": color}

    return albs_ | {
        "to_tensor": [A.Normalize(mean=mean, std=std, p=1), ToTensorV2()],
        # Inverts the normalization above (x * std + mean) on float tensors.
        "un_normalize": [
            A.Normalize(
                mean=[-m / s for m, s in zip(mean, std)],
                std=[1.0 / s for s in std],
                always_apply=True,
                max_pixel_value=1.0,
            ),
        ],
    }

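# Sketch of the mapping build_albumentations returns (group name -> list of
# transforms) and how the groups are meant to be composed; the argument
# values here are illustrative:
#
#     albs = build_albumentations(image_size=224, crop=180)
#     pprint(list(albs))
#     # ['resize', 'crop_and_pad', 'affine', 'color', 'to_tensor', 'un_normalize']
#     train_tf = A.Compose(
#         albs["resize"] + albs["affine"] + albs["color"] + albs["to_tensor"]
#     )
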
def get_augmentations(
    image_size: int = 224,
    gamma=(60, 180),
    brightness_limit=0.15,
    contrast_limit=0.25,
    crop=180,
    center_crop: int = -1,
    kinds: list = ["resize", "to_tensor"],
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    brightness_target=None,
    brightness_thresholds=None,
    affine_transforms={"H": 0.3, "V": 0.3, "R": 0.3, "T": 0.3},
):
    # Work on a copy so the caller's (or the default) list is never mutated.
    kinds = list(kinds)
    if "train" in kinds:
        # "train" is shorthand for the affine + color augmentation groups.
        kinds.insert(kinds.index("train"), "affine")
        kinds.insert(kinds.index("train"), "color")
        kinds.remove("train")
    td_ = build_albumentations(
        image_size=image_size,
        gamma=gamma,
        brightness_limit=brightness_limit,
        contrast_limit=contrast_limit,
        crop=crop,
        center_crop=center_crop,
        mean=mean,
        std=std,
        brightness_target=brightness_target,
        brightness_thresholds=brightness_thresholds,
        affine_transforms=affine_transforms,
    )
    augs = []
    for k in kinds:
        if k:
            augs += td_[k]
    return A.Compose(augs)

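# Typical calls (values illustrative): "train" expands to the affine and
# color groups, so the two pipelines below differ only in random augmentation:
#
#     train_tf = get_augmentations(image_size=224, kinds=["resize", "train", "to_tensor"])
#     valid_tf = get_augmentations(image_size=224, kinds=["resize", "to_tensor"])
#     tensor = train_tf(image=rgb_uint8_array)["image"]  # CHW float tensor
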
class MlcPatches(Dataset):
    def __init__(self, dataframe, transform, path_to_images) -> None:
        super().__init__()
        self.dataframe = dataframe
        self.transform = transform
        self.path_to_images = path_to_images

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, index):
        img = self.transform(image=self.get_image(index=index))["image"]
        # The label is a constant placeholder tensor; only the image matters
        # for the augmentation previews this dataset is used for.
        return {"image": img, "labels": torch.tensor([1])}

    def get_image(self, index):
        return ci.load_image(
            file_name=self.dataframe.file_name.iloc[index],
            path_to_images=self.path_to_images,
        )

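# Sketch of wrapping the dataset in a DataLoader (``df`` and the image path
# are assumptions):
#
#     ds = MlcPatches(
#         dataframe=df,
#         transform=get_augmentations(image_size=224),
#         path_to_images="patches/",
#     )
#     dl = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)
#     batch = next(iter(dl))  # batch["image"]: (32, 3, 224, 224)
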
def test_augmentations(
    df,
    image_size,
    path_to_images,
    columns: list = [],
    kinds: list = ["resize", "to_tensor"],
    rows: int = 2,
    cols: int = 4,
    **aug_params,
):
    sample = df.sample(n=1)
    src_dataset = MlcPatches(
        dataframe=sample,
        transform=get_augmentations(
            image_size=image_size, kinds=["resize", "to_tensor"], **aug_params
        ),
        path_to_images=path_to_images,
    )

    test_dataset = MlcPatches(
        dataframe=sample,
        transform=get_augmentations(image_size=image_size, kinds=kinds, **aug_params),
        path_to_images=path_to_images,
    )
    pprint(sample[[c for c in ["file_name"] + columns if c in sample]])
    # Indexing test_dataset re-runs the random pipeline, so every grid tile
    # shows a fresh augmentation of the same source patch.
    cp.tensor_image_to_grid(
        images=[(src_dataset[0]["image"], "source")]
        + [(test_dataset[0]["image"], "augmented") for i in range(rows * cols)],
        transform=get_augmentations(
            image_size=image_size, kinds=["un_normalize"], **aug_params
        ),
        row_count=rows,
        col_count=cols,
        figsize=(cols * 4, rows * 4),
    )
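
# Example invocation (``patches_df`` and the path are assumptions):
#
#     test_augmentations(
#         df=patches_df,
#         image_size=224,
#         path_to_images="patches/",
#         kinds=["resize", "train", "to_tensor"],
#         rows=2,
#         cols=4,
#     )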