|
""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch |
|
|
|
This code implements the searched ImageNet policies with various tweaks and improvements and |
|
does not include any of the search code. |
|
|
|
AA and RA Implementation adapted from: |
|
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py |
|
|
|
AugMix adapted from: |
|
https://github.com/google-research/augmix |
|
|
|
3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md |
|
|
|
Papers: |
|
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 |
|
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 |
|
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 |
|
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 |
|
3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118 |
|
|
|
Hacked together by / Copyright 2019, Ross Wightman |
|
""" |
|
import random |
|
import math |
|
import re |
|
from functools import partial |
|
from typing import Dict, List, Optional, Union |
|
|
|
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter |
|
import PIL |
|
import numpy as np |
|
|
|
|
|
# (major, minor) of the installed Pillow; compared against tuples to gate version-specific paths.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split('.')[:2])

# Default fill colour handed to the geometric ops as `fillcolor`.
_FILL = (128, 128, 128)

# Magnitude levels are divided by this value before being scaled into op arguments.
_LEVEL_DENOM = 10.

# Hyper-parameters used whenever a caller supplies none (or an empty mapping).
_HPARAMS_DEFAULT = {
    'translate_const': 250,
    'img_mean': _FILL,
}

# Take the filters from Image.Resampling when this Pillow has it, otherwise from Image itself.
if not hasattr(Image, "Resampling"):
    _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
    _DEFAULT_INTERPOLATION = Image.BICUBIC
else:
    _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
    _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
|
|
|
|
|
def _interpolation(kwargs):
    """Pop ``'resample'`` from ``kwargs`` and resolve it to one filter.

    A list or tuple is treated as a pool and one entry is drawn at random;
    any other value is returned unchanged. When the key is absent the
    module default interpolation is used.
    """
    choice = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
    return random.choice(choice) if isinstance(choice, (list, tuple)) else choice
|
|
|
|
|
def _check_args_tf(kwargs):
    """Normalise, in place, the keyword arguments forwarded to PIL transforms.

    ``fillcolor`` is removed when the installed Pillow is older than 5.0, and
    ``resample`` is always (re)written with a single resolved filter.
    """
    if _PIL_VER < (5, 0) and 'fillcolor' in kwargs:
        del kwargs['fillcolor']
    kwargs['resample'] = _interpolation(kwargs)
|
|
|
|
|
def shear_x(img, factor, **kwargs):
    """Apply an affine transform whose x output depends on y through ``factor``."""
    _check_args_tf(kwargs)
    affine = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
|
|
|
|
|
def shear_y(img, factor, **kwargs):
    """Apply an affine transform whose y output depends on x through ``factor``."""
    _check_args_tf(kwargs)
    affine = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
|
|
|
|
|
def translate_x_rel(img, pct, **kwargs):
    """Translate along x by ``pct`` times the image width (``img.size[0]``)."""
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    affine = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
|
|
|
|
|
def translate_y_rel(img, pct, **kwargs):
    """Translate along y by ``pct`` times the image height (``img.size[1]``)."""
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    affine = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
|
|
|
|
|
def translate_x_abs(img, pixels, **kwargs):
    """Translate along x by an absolute ``pixels`` offset."""
    _check_args_tf(kwargs)
    affine = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
|
|
|
|
|
def translate_y_abs(img, pixels, **kwargs):
    """Translate along y by an absolute ``pixels`` offset."""
    _check_args_tf(kwargs)
    affine = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
|
|
|
|
|
def rotate(img, degrees, **kwargs):
    """Rotate ``img`` by ``degrees``, picking the call path from the Pillow version.

    * Pillow >= 5.2: ``img.rotate`` receives all (normalised) keyword arguments.
    * Pillow 5.0 / 5.1: the rotation about the image centre is written out as an
      affine matrix and applied through ``img.transform``.
    * Older Pillow: ``img.rotate`` is called with the resample filter only.
    """
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if _PIL_VER < (5, 0):
        return img.rotate(degrees, resample=kwargs['resample'])

    width, height = img.size
    center_x, center_y = width / 2.0, height / 2.0
    theta = -math.radians(degrees)
    cos_t = round(math.cos(theta), 15)
    sin_t = round(math.sin(theta), 15)
    neg_sin_t = round(-math.sin(theta), 15)

    # Push the negated centre through the linear part (translation terms still 0.0),
    # then shift back by the centre so the rotation pivots on the image middle.
    shift_x = cos_t * -center_x + sin_t * -center_y + 0.0
    shift_y = neg_sin_t * -center_x + cos_t * -center_y + 0.0
    matrix = [
        cos_t,
        sin_t,
        shift_x + center_x,
        neg_sin_t,
        cos_t,
        shift_y + center_y,
    ]
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
|
|
|
|
|
def auto_contrast(img, **__):
    """Return ``ImageOps.autocontrast(img)``; extra keyword arguments are ignored."""
    result = ImageOps.autocontrast(img)
    return result
|
|
|
|
|
def invert(img, **__):
    """Return ``ImageOps.invert(img)``; extra keyword arguments are ignored."""
    result = ImageOps.invert(img)
    return result
|
|
|
|
|
def equalize(img, **__):
    """Return ``ImageOps.equalize(img)``; extra keyword arguments are ignored."""
    result = ImageOps.equalize(img)
    return result
|
|
|
|
|
def solarize(img, thresh, **__):
    """Return ``ImageOps.solarize(img, thresh)``; extra keyword arguments are ignored."""
    result = ImageOps.solarize(img, thresh)
    return result
|
|
|
|
|
def solarize_add(img, add, thresh=128, **__):
    """Add ``add`` (capped at 255) to every pixel value below ``thresh``.

    Only ``"L"`` and ``"RGB"`` images are processed; images in any other mode
    are returned unchanged.
    """
    # 256-entry lookup table: values under the threshold are brightened, the rest kept.
    lut = [min(255, value + add) if value < thresh else value for value in range(256)]

    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB":
        # One copy of the table per band.
        lut = lut * 3
    return img.point(lut)
|
|
|
|
|
def posterize(img, bits_to_keep, **__):
    """Posterize ``img`` to ``bits_to_keep`` bits; 8 or more bits is a no-op."""
    if bits_to_keep < 8:
        return ImageOps.posterize(img, bits_to_keep)
    return img
|
|
|
|
|
def contrast(img, factor, **__):
    """Apply ``ImageEnhance.Contrast`` with ``factor``; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
|
|
|
|
|
def color(img, factor, **__):
    """Apply ``ImageEnhance.Color`` with ``factor``; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
|
|
|
|
|
def brightness(img, factor, **__):
    """Apply ``ImageEnhance.Brightness`` with ``factor``; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
|
|
|
|
|
def sharpness(img, factor, **__):
    """Apply ``ImageEnhance.Sharpness`` with ``factor``; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
|
|
|
|
|
def gaussian_blur(img, factor, **__):
    """Blur ``img`` with a Gaussian filter whose radius is ``factor``."""
    blur = ImageFilter.GaussianBlur(radius=factor)
    return img.filter(blur)
|
|
|
|
|
def gaussian_blur_rand(img, factor, **__):
    """Blur ``img`` with a Gaussian radius drawn uniformly between 0.1 and ``2.0 * factor``."""
    radius_min, radius_max = 0.1, 2.0
    radius = random.uniform(radius_min, radius_max * factor)
    return img.filter(ImageFilter.GaussianBlur(radius=radius))
|
|
|
|
|
def desaturate(img, factor, **_):
    """Reduce colour saturation; the enhancement factor is ``1 - factor`` clamped to [0, 1]."""
    keep = 1. - factor
    keep = max(0., keep)
    keep = min(1., keep)
    return ImageEnhance.Color(img).enhance(keep)
|
|
|
|
|
def _randomly_negate(v): |
|
"""With 50% prob, negate the value""" |
|
return -v if random.random() > 0.5 else v |
|
|
|
|
|
def _rotate_level_to_arg(level, _hparams):
    """Scale ``level`` to a rotation angle (30 degrees at level ``_LEVEL_DENOM``) with a random sign.

    Returns a 1-tuple so it can be splatted into the op call.
    """
    degrees = (level / _LEVEL_DENOM) * 30.
    return (_randomly_negate(degrees),)
|
|
|
|
|
def _enhance_level_to_arg(level, _hparams):
    """Map ``level`` linearly to an enhancement factor: 0.1 at level 0, 1.9 at level ``_LEVEL_DENOM``.

    Returns a 1-tuple.
    """
    factor = (level / _LEVEL_DENOM) * 1.8 + 0.1
    return (factor,)
|
|
|
|
|
def _enhance_increasing_level_to_arg(level, _hparams):
    """Map ``level`` to an enhancement factor that moves away from 1.0 as the level grows.

    The offset is ``0.9 * level / _LEVEL_DENOM`` with a random sign, and the
    resulting factor is floored at 0.1. Returns a 1-tuple.
    """
    offset = _randomly_negate((level / _LEVEL_DENOM) * .9)
    return (max(0.1, 1.0 + offset),)
|
|
|
|
|
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
    """Interpolate ``level`` linearly between ``min_val`` (level 0) and ``max_val`` (level ``_LEVEL_DENOM``).

    With ``clamp`` set, the result is limited to ``[min_val, max_val]``.
    Returns a 1-tuple.
    """
    fraction = level / _LEVEL_DENOM
    value = min_val + (max_val - min_val) * fraction
    if clamp:
        value = max(min_val, min(max_val, value))
    return (value,)
|
|
|
|
|
def _shear_level_to_arg(level, _hparams):
    """Scale ``level`` to a shear factor (0.3 at level ``_LEVEL_DENOM``) with a random sign.

    Returns a 1-tuple.
    """
    shear = (level / _LEVEL_DENOM) * 0.3
    return (_randomly_negate(shear),)
|
|
|
|
|
def _translate_abs_level_to_arg(level, hparams):
    """Scale ``level`` to a pixel offset of up to ``hparams['translate_const']``, with a random sign.

    Raises ``KeyError`` when ``hparams`` has no ``'translate_const'`` entry.
    Returns a 1-tuple.
    """
    const = hparams['translate_const']
    pixels = (level / _LEVEL_DENOM) * float(const)
    return (_randomly_negate(pixels),)
|
|
|
|
|
def _translate_rel_level_to_arg(level, hparams):
    """Scale ``level`` to a relative offset of up to ``translate_pct`` (default 0.45), with a random sign.

    Returns a 1-tuple.
    """
    max_pct = hparams.get('translate_pct', 0.45)
    pct = (level / _LEVEL_DENOM) * max_pct
    return (_randomly_negate(pct),)
|
|
|
|
|
def _posterize_level_to_arg(level, _hparams):
    """Map ``level`` to a bit count: 0 at level 0 up to 4 at level ``_LEVEL_DENOM`` (truncated to int).

    Returns a 1-tuple.
    """
    bits = int((level / _LEVEL_DENOM) * 4)
    return (bits,)
|
|
|
|
|
def _posterize_increasing_level_to_arg(level, hparams):
    """Map ``level`` to a bit count that shrinks as the level grows: 4 at level 0 down to 0.

    Computed as ``4`` minus the value from ``_posterize_level_to_arg``.
    Returns a 1-tuple.
    """
    (bits,) = _posterize_level_to_arg(level, hparams)
    return (4 - bits,)
|
|
|
|
|
def _posterize_original_level_to_arg(level, _hparams):
    """Map ``level`` to a bit count: 4 at level 0 up to 8 at level ``_LEVEL_DENOM`` (truncated to int).

    Returns a 1-tuple.
    """
    bits = int((level / _LEVEL_DENOM) * 4) + 4
    return (bits,)
|
|
|
|
|
def _solarize_level_to_arg(level, _hparams):
    """Map ``level`` to a solarize threshold: 0 at level 0, capped at 256.

    Returns a 1-tuple.
    """
    threshold = int((level / _LEVEL_DENOM) * 256)
    return (min(256, threshold),)
|
|
|
|
|
def _solarize_increasing_level_to_arg(level, _hparams):
    """Map ``level`` to a solarize threshold that falls as the level grows: 256 at level 0.

    Computed as ``256`` minus the value from ``_solarize_level_to_arg``.
    Returns a 1-tuple.
    """
    (threshold,) = _solarize_level_to_arg(level, _hparams)
    return (256 - threshold,)
|
|
|
|
|
def _solarize_add_level_to_arg(level, _hparams):
    """Map ``level`` to the amount added by ``solarize_add``: 0 at level 0, 110 at level ``_LEVEL_DENOM``, capped at 128.

    Returns a 1-tuple.
    """
    amount = int((level / _LEVEL_DENOM) * 110)
    return (min(128, amount),)
|
|
|
|
|
# Op name -> function turning (magnitude level, hparams) into the tuple of positional
# arguments for that op. ``None`` marks ops that take no magnitude-derived argument
# (AugmentOp passes an empty tuple for them).
LEVEL_TO_ARG = {
    'AutoContrast': None,
    'Equalize': None,
    'Invert': None,
    'Rotate': _rotate_level_to_arg,
    # The three Posterize variants share one op and differ only in the level mapping.
    'Posterize': _posterize_level_to_arg,
    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
    'PosterizeOriginal': _posterize_original_level_to_arg,
    'Solarize': _solarize_level_to_arg,
    'SolarizeIncreasing': _solarize_increasing_level_to_arg,
    'SolarizeAdd': _solarize_add_level_to_arg,
    'Color': _enhance_level_to_arg,
    'ColorIncreasing': _enhance_increasing_level_to_arg,
    'Contrast': _enhance_level_to_arg,
    'ContrastIncreasing': _enhance_increasing_level_to_arg,
    'Brightness': _enhance_level_to_arg,
    'BrightnessIncreasing': _enhance_increasing_level_to_arg,
    'Sharpness': _enhance_level_to_arg,
    'SharpnessIncreasing': _enhance_increasing_level_to_arg,
    'ShearX': _shear_level_to_arg,
    'ShearY': _shear_level_to_arg,
    'TranslateX': _translate_abs_level_to_arg,
    'TranslateY': _translate_abs_level_to_arg,
    'TranslateXRel': _translate_rel_level_to_arg,
    'TranslateYRel': _translate_rel_level_to_arg,
    # Linear min/max mappings with op-specific ranges; GaussianBlurRand uses the 0..1 default.
    'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0),
    'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0),
    'GaussianBlurRand': _minmax_level_to_arg,
}
|
|
|
|
|
# Op name -> image function implementing it. Keys mirror LEVEL_TO_ARG; the
# '*Increasing' / '*Original' names reuse the base op and differ only in how the
# magnitude level is mapped to the op argument.
NAME_TO_OP = {
    'AutoContrast': auto_contrast,
    'Equalize': equalize,
    'Invert': invert,
    'Rotate': rotate,
    'Posterize': posterize,
    'PosterizeIncreasing': posterize,
    'PosterizeOriginal': posterize,
    'Solarize': solarize,
    'SolarizeIncreasing': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'ColorIncreasing': color,
    'Contrast': contrast,
    'ContrastIncreasing': contrast,
    'Brightness': brightness,
    'BrightnessIncreasing': brightness,
    'Sharpness': sharpness,
    'SharpnessIncreasing': sharpness,
    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x_abs,
    'TranslateY': translate_y_abs,
    'TranslateXRel': translate_x_rel,
    'TranslateYRel': translate_y_rel,
    'Desaturate': desaturate,
    'GaussianBlur': gaussian_blur,
    'GaussianBlurRand': gaussian_blur_rand,
}
|
|
|
|
|
class AugmentOp:
    """A single named augmentation applied with a probability and a magnitude.

    Args:
        name: Key into ``NAME_TO_OP`` / ``LEVEL_TO_ARG``.
        prob: Probability of applying the op on each call.
        magnitude: Base magnitude level handed to the level-mapping function.
        hparams: Optional hyper-parameter dict; ``_HPARAMS_DEFAULT`` is used when
            it is ``None`` or empty. Keys read here: ``img_mean``,
            ``interpolation``, ``magnitude_std``, ``magnitude_max``.
    """

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        if not hparams:
            hparams = _HPARAMS_DEFAULT
        self.name = name
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        # Keyword arguments forwarded to every op call.
        self.kwargs = {
            'fillcolor': hparams.get('img_mean', _FILL),
            'resample': hparams.get('interpolation', _RANDOM_INTERPOLATION),
        }

        # magnitude_std > 0 randomises the magnitude per call: a Gaussian around the
        # base magnitude, or uniform in [0, magnitude] when it is float('inf').
        self.magnitude_std = self.hparams.get('magnitude_std', 0)
        # Optional override of the upper clamp (defaults to _LEVEL_DENOM when unset/falsy).
        self.magnitude_max = self.hparams.get('magnitude_max', None)

    def __call__(self, img):
        """Apply the op to ``img`` with probability ``prob``; otherwise return ``img`` unchanged."""
        if self.prob < 1.0 and random.random() > self.prob:
            return img

        magnitude = self.magnitude
        if self.magnitude_std == float('inf'):
            # inf selects uniform sampling
            magnitude = random.uniform(0, magnitude)
        elif self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)

        # Clamp the (possibly randomised) magnitude into [0, upper_bound].
        upper_bound = self.magnitude_max or _LEVEL_DENOM
        magnitude = max(0., min(magnitude, upper_bound))

        if self.level_fn is None:
            level_args = tuple()
        else:
            level_args = self.level_fn(magnitude, self.hparams)
        return self.aug_fn(img, *level_args, **self.kwargs)

    def __repr__(self):
        parts = [
            self.__class__.__name__ + f'(name={self.name}, p={self.prob}',
            f', m={self.magnitude}, mstd={self.magnitude_std}',
        ]
        if self.magnitude_max is not None:
            parts.append(f', mmax={self.magnitude_max}')
        parts.append(')')
        return ''.join(parts)
|
|
|
|
|
def auto_augment_policy_v0(hparams):
    """Build the 'v0' AutoAugment policy.

    Each sub-policy is a pair of ``(op name, probability, magnitude)`` entries.

    Returns:
        A list of sub-policies, each a list of ``AugmentOp`` built with ``hparams``.
    """
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AugmentOp(*op_spec, hparams=hparams) for op_spec in sub_policy]
        for sub_policy in policy
    ]
|
|
|
|
|
def auto_augment_policy_v0r(hparams):
    """Build the 'v0r' AutoAugment policy.

    The table matches the 'v0' policy except that 'PosterizeIncreasing' is used
    in place of 'Posterize'.

    Returns:
        A list of sub-policies, each a list of ``AugmentOp`` built with ``hparams``.
    """
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AugmentOp(*op_spec, hparams=hparams) for op_spec in sub_policy]
        for sub_policy in policy
    ]
|
|
|
|
|
def auto_augment_policy_original(hparams):
    """Build the 'original' AutoAugment policy (uses the 'PosterizeOriginal' level mapping).

    Returns:
        A list of sub-policies, each a list of ``AugmentOp`` built with ``hparams``.
    """
    policy = [
        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    return [
        [AugmentOp(*op_spec, hparams=hparams) for op_spec in sub_policy]
        for sub_policy in policy
    ]
|
|
|
|
|
def auto_augment_policy_originalr(hparams):
    """Build the 'originalr' AutoAugment policy.

    The table matches the 'original' policy except that 'PosterizeIncreasing' is
    used in place of 'PosterizeOriginal'.

    Returns:
        A list of sub-policies, each a list of ``AugmentOp`` built with ``hparams``.
    """
    policy = [
        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    return [
        [AugmentOp(*op_spec, hparams=hparams) for op_spec in sub_policy]
        for sub_policy in policy
    ]
|
|
|
|
|
def auto_augment_policy_3a(hparams):
    """Build the '3a' policy: three single-op sub-policies, each applied with probability 1.0.

    Returns:
        A list of sub-policies, each a list of ``AugmentOp`` built with ``hparams``.
    """
    policy = [
        [('Solarize', 1.0, 5)],
        [('Desaturate', 1.0, 10)],
        [('GaussianBlurRand', 1.0, 10)],
    ]
    return [
        [AugmentOp(*op_spec, hparams=hparams) for op_spec in sub_policy]
        for sub_policy in policy
    ]
|
|
|
|
|
def auto_augment_policy(name='v0', hparams=None):
    """Build the AutoAugment policy registered under ``name``.

    Args:
        name: One of 'original', 'originalr', 'v0', 'v0r', '3a'.
        hparams: Hyper-parameters for the ops; ``_HPARAMS_DEFAULT`` when ``None`` or empty.

    Returns:
        The list of sub-policies produced by the matching builder.

    Raises:
        AssertionError: If ``name`` matches no known policy.
    """
    hparams = hparams or _HPARAMS_DEFAULT
    builders = (
        ('original', auto_augment_policy_original),
        ('originalr', auto_augment_policy_originalr),
        ('v0', auto_augment_policy_v0),
        ('v0r', auto_augment_policy_v0r),
        ('3a', auto_augment_policy_3a),
    )
    for policy_name, build in builders:
        if name == policy_name:
            return build(hparams)
    assert False, f'Unknown AA policy {name}'
|
|
|
|
|
class AutoAugment:
    """Transform that applies one randomly chosen sub-policy per call.

    Args:
        policy: Sequence of sub-policies, each a sequence of callables applied in order.
    """

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        """Pick a sub-policy at random and run its ops over ``img`` in sequence."""
        chosen = random.choice(self.policy)
        for apply_op in chosen:
            img = apply_op(img)
        return img

    def __repr__(self):
        pieces = [self.__class__.__name__ + '(policy=']
        for sub_policy in self.policy:
            pieces.append('\n\t[')
            pieces.append(', '.join([str(op) for op in sub_policy]))
            pieces.append(']')
        pieces.append(')')
        return ''.join(pieces)
|
|
|
|
|
def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
    """
    Create a AutoAugment transform

    Args:
        config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
            dashes ('-').
            The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr', '3a').

            The remaining sections:
                'mstd' -  float std deviation of magnitude noise applied
            Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5

        hparams: Other hparams (kwargs) for the AutoAugmentation scheme. A 'magnitude_std' entry given here
            takes precedence over an 'mstd' section of ``config_str``. The dict is not modified.

    Returns:
        A PyTorch compatible Transform

    Raises:
        AssertionError: On an unknown config section or an unknown policy name.
    """
    # Work on a private copy: hparams=None used to crash on `.setdefault`, and a
    # caller-owned dict would otherwise be mutated and leak state into later calls.
    hparams = dict(_HPARAMS_DEFAULT) if hparams is None else dict(hparams)

    policy_name, *sections = config_str.split('-')
    for section in sections:
        # Split 'mstd0.5' into key 'mstd' and value '0.5'; sections without a digit are skipped.
        cs = re.split(r'(\d.*)', section)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        else:
            assert False, 'Unknown AutoAugment config section'
    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
    return AutoAugment(aa_policy)
|
|
|
|
|
# Default RandAugment op names (used when the 'inc' option is off and no
# transform set is named).
_RAND_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'Posterize',
    'Solarize',
    'SolarizeAdd',
    'Color',
    'Contrast',
    'Brightness',
    'Sharpness',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
]
|
|
|
|
|
# RandAugment op names for the 'increasing' mode: same list as _RAND_TRANSFORMS
# with the '*Increasing' variants substituted where one exists.
_RAND_INCREASING_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'SolarizeAdd',
    'ColorIncreasing',
    'ContrastIncreasing',
    'BrightnessIncreasing',
    'SharpnessIncreasing',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
]
|
|
|
|
|
# Op names returned by rand_augment_choices('3a').
_RAND_3A = [
    'SolarizeIncreasing',
    'Desaturate',
    'GaussianBlur',
]
|
|
|
|
|
# Op name -> relative sampling weight, returned by rand_augment_choices('3aw').
# Weights need not sum to 1; _get_weighted_transforms normalises them.
_RAND_WEIGHTED_3A = {
    'SolarizeIncreasing': 6,
    'Desaturate': 6,
    'GaussianBlur': 6,
    'Rotate': 3,
    'ShearX': 2,
    'ShearY': 2,
    'PosterizeIncreasing': 1,
    'AutoContrast': 1,
    'ColorIncreasing': 1,
    'SharpnessIncreasing': 1,
    'ContrastIncreasing': 1,
    'BrightnessIncreasing': 1,
    'Equalize': 1,
    'Invert': 1,
}
|
|
|
|
|
|
|
|
|
# Op name -> relative sampling weight, returned by rand_augment_choices('weights').
# Weights need not sum to 1; _get_weighted_transforms normalises them.
_RAND_WEIGHTED_0 = {
    'Rotate': 3,
    'ShearX': 2,
    'ShearY': 2,
    'TranslateXRel': 1,
    'TranslateYRel': 1,
    'ColorIncreasing': .25,
    'SharpnessIncreasing': 0.25,
    'AutoContrast': 0.25,
    'SolarizeIncreasing': .05,
    'SolarizeAdd': .05,
    'ContrastIncreasing': .05,
    'BrightnessIncreasing': .05,
    'Equalize': .05,
    'PosterizeIncreasing': 0.05,
    'Invert': 0.05,
}
|
|
|
|
|
def _get_weighted_transforms(transforms: Dict): |
|
transforms, probs = list(zip(*transforms.items())) |
|
probs = np.array(probs) |
|
probs = probs / np.sum(probs) |
|
return transforms, probs |
|
|
|
|
|
def rand_augment_choices(name: str, increasing=True):
    """Look up a RandAugment transform set by name.

    'weights', '3aw' and '3a' select the matching named sets; any other name
    falls back to the increasing or the default list depending on ``increasing``.
    """
    named_sets = (
        ('weights', _RAND_WEIGHTED_0),
        ('3aw', _RAND_WEIGHTED_3A),
        ('3a', _RAND_3A),
    )
    for key, choices in named_sets:
        if name == key:
            return choices
    if increasing:
        return _RAND_INCREASING_TRANSFORMS
    return _RAND_TRANSFORMS
|
|
|
|
|
def rand_augment_ops(
        magnitude: Union[int, float] = 10,
        prob: float = 0.5,
        hparams: Optional[Dict] = None,
        transforms: Optional[Union[Dict, List]] = None,
):
    """Create one ``AugmentOp`` per transform name, all sharing ``prob``, ``magnitude`` and ``hparams``.

    ``hparams`` and ``transforms`` fall back to ``_HPARAMS_DEFAULT`` and
    ``_RAND_TRANSFORMS`` when ``None`` or empty.
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _RAND_TRANSFORMS
    ops = []
    for op_name in transforms:
        ops.append(AugmentOp(op_name, prob=prob, magnitude=magnitude, hparams=hparams))
    return ops
|
|
|
|
|
class RandAugment:
    """Transform that applies ``num_layers`` randomly selected ops per call.

    Args:
        ops: Candidate callables.
        num_layers: Number of ops drawn and applied to each image.
        choice_weights: Optional selection probabilities, one per op.
    """

    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        """Draw ``num_layers`` ops and run them over ``img`` in the drawn order."""
        # Ops are drawn with replacement only when no per-op weights are given.
        with_replacement = self.choice_weights is None
        selected = np.random.choice(
            self.ops,
            self.num_layers,
            replace=with_replacement,
            p=self.choice_weights,
        )
        for apply_op in selected:
            img = apply_op(img)
        return img

    def __repr__(self):
        pieces = [self.__class__.__name__ + f'(n={self.num_layers}, ops=']
        pieces.extend(f'\n\t{op}' for op in self.ops)
        pieces.append(')')
        return ''.join(pieces)
|
|
|
|
|
def rand_augment_transform(
        config_str: str,
        hparams: Optional[Dict] = None,
        transforms: Optional[Union[str, Dict, List]] = None,
):
    """
    Create a RandAugment transform

    Args:
        config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
            by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
            The remaining sections, not order specific, determine
                'm' - integer magnitude of rand augment
                'n' - integer num layers (number of transform ops selected per image)
                'p' - float probability of applying each layer (default 0.5)
                'mstd' -  float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
                'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
                'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
                't' - str name of transform set to use
            Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
            'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2

        hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme. 'magnitude_std' / 'magnitude_max'
            entries given here take precedence over the 'mstd' / 'mmax' sections. The dict is not modified.

        transforms: Explicit transform set: a name understood by ``rand_augment_choices``, a list of op names,
            or a dict mapping op names to sampling weights. Takes precedence over a 't' section.

    Returns:
        A PyTorch compatible Transform

    Raises:
        AssertionError: If the first section is not 'rand' or a section key is unknown.
    """
    # Work on a private copy: hparams=None used to crash on `.setdefault`, and a
    # caller-owned dict would otherwise be mutated and leak state into later calls.
    hparams = dict(_HPARAMS_DEFAULT) if hparams is None else dict(hparams)

    magnitude = _LEVEL_DENOM  # default to _LEVEL_DENOM for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    increasing = False
    prob = 0.5
    config = config_str.split('-')
    assert config[0] == 'rand'
    for c in config[1:]:
        if c.startswith('t'):
            # Named transform set; an explicit `transforms` argument wins over the config string.
            val = str(c[1:])
            if transforms is None:
                transforms = val
            continue

        # Split e.g. 'mstd0.5' into key 'mstd' and value '0.5'; sections without a digit are skipped.
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param / randomization of magnitude values
            mstd = float(val)
            if mstd > 100:
                # use uniform sampling in 0 to magnitude if mstd is > 100
                mstd = float('inf')
            hparams.setdefault('magnitude_std', mstd)
        elif key == 'mmax':
            # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
            hparams.setdefault('magnitude_max', int(val))
        elif key == 'inc':
            # `val` is a string, so bool(val) was True even for 'inc0'; go through int first.
            increasing = bool(int(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'p':
            prob = float(val)
        else:
            assert False, 'Unknown RandAugment config section'

    if isinstance(transforms, str):
        transforms = rand_augment_choices(transforms, increasing=increasing)
    elif transforms is None:
        transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS

    # A dict carries per-op sampling weights; split it into names and normalised probabilities.
    choice_weights = None
    if isinstance(transforms, Dict):
        transforms, choice_weights = _get_weighted_transforms(transforms)

    ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
|
|
|
|
# Default op names used by augmix_ops when no transform list is supplied.
_AUGMIX_TRANSFORMS = [
    'AutoContrast',
    'ColorIncreasing',
    'ContrastIncreasing',
    'BrightnessIncreasing',
    'SharpnessIncreasing',
    'Equalize',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
]
|
|
|
|
|
def augmix_ops(
        magnitude: Union[int, float] = 10,
        hparams: Optional[Dict] = None,
        transforms: Optional[Union[str, Dict, List]] = None,
):
    """Create one always-applied (``prob=1.0``) ``AugmentOp`` per transform name.

    ``hparams`` and ``transforms`` fall back to ``_HPARAMS_DEFAULT`` and
    ``_AUGMIX_TRANSFORMS`` when ``None`` or empty.
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _AUGMIX_TRANSFORMS
    ops = []
    for op_name in transforms:
        ops.append(AugmentOp(op_name, prob=1.0, magnitude=magnitude, hparams=hparams))
    return ops
|
|
|
|
|
class AugMixAugment:
    """ AugMix Transform
    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781

    Args:
        ops: Candidate augmentation callables.
        alpha: Concentration parameter for both the Dirichlet draw of the per-branch
            mixing weights and the Beta draw of the final blend factor.
        width: Number of augmentation branches mixed together.
        depth: Ops applied per branch; a value <= 0 draws a depth in [1, 3] per branch.
        blended: Use the sequential-blend variant instead of the float accumulation buffer.
    """
    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended

    def _calc_blended_weights(self, ws, m):
        """Convert mixing weights ``ws`` (scaled by ``m``) into per-step blend factors.

        The factors are computed back to front so that blending the branches in
        order, each with its own factor, reproduces the requested weights.
        """
        ws = ws * m
        cump = 1.
        rws = []
        for w in ws[::-1]:
            alpha = w / cump
            cump *= (1 - alpha)
            rws.append(alpha)
        return np.array(rws[::-1], dtype=np.float32)

    def _apply_blended(self, img, mixing_weights, m):
        """Mix branches by blending each augmented image into the running result.

        Every branch starts from the untouched original and is blended in with
        ``Image.blend``, so no floating point accumulation buffer is needed.
        """
        img_orig = img.copy()
        ws = self._calc_blended_weights(mixing_weights, m)
        for w in ws:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img_orig  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            img = Image.blend(img, img_aug, w)
        return img

    def _apply_basic(self, img, mixing_weights, m):
        """Mix branches in a float32 buffer, then blend the mix with the original by ``m``."""
        # PIL's `size` is (width, height) while the array form of an image is
        # (height, width[, channels]). Take the accumulator shape from the array
        # itself so non-square and single-band images mix correctly.
        mixed = np.zeros_like(np.asarray(img, dtype=np.float32))
        for mw in mixing_weights:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            mixed += mw * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255., out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        """Draw mixing weights and a blend factor, then apply the configured mixing variant."""
        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.blended:
            mixed = self._apply_blended(img, mixing_weights, m)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed

    def __repr__(self):
        fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
        for op in self.ops:
            fs += f'\n\t{op}'
        fs += ')'
        return fs
|
|
|
|
|
def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None):
    """ Create AugMix PyTorch transform

    Args:
        config_str (str): String defining configuration of AugMix augmentation. Consists of multiple sections
            separated by dashes ('-'). The first section must be 'augmix'.
            The remaining sections, not order specific, determine
                'm' - integer magnitude (severity) of augmentation mix (default: 3)
                'w' - integer width of augmentation chain (default: 3)
                'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
                'a' - float alpha for the mixing-weight and blend-factor distributions (default: 1.0)
                'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
                'mstd' -  float std deviation of magnitude noise applied (default: inf, i.e. uniform sampling
                    in [0, magnitude])
            Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2

        hparams: Other hparams (kwargs) for the Augmentation transforms. A 'magnitude_std' entry given here
            takes precedence over an 'mstd' section. The dict is not modified.

    Returns:
         A PyTorch compatible Transform

    Raises:
        AssertionError: If the first section is not 'augmix' or a section key is unknown.
    """
    # Work on a private copy: hparams=None used to crash on `.setdefault` (always, because
    # of the unconditional default below), and a caller-owned dict would otherwise be mutated.
    hparams = dict(_HPARAMS_DEFAULT) if hparams is None else dict(hparams)

    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.
    blended = False
    config = config_str.split('-')
    assert config[0] == 'augmix'
    for c in config[1:]:
        # Split e.g. 'w4' into key 'w' and value '4'; sections without a digit are skipped.
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'w':
            width = int(val)
        elif key == 'd':
            depth = int(val)
        elif key == 'a':
            alpha = float(val)
        elif key == 'b':
            # `val` is a string, so bool(val) was True even for 'b0'; go through int first.
            blended = bool(int(val))
        else:
            assert False, 'Unknown AugMix config section'
    # default to uniform sampling of magnitude when no std was given anywhere
    hparams.setdefault('magnitude_std', float('inf'))
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
|
|