|
|
|
|
|
|
|
import inspect
|
|
import numpy as np
|
|
import pprint
|
|
from typing import Any, List, Optional, Tuple, Union
|
|
from fvcore.transforms.transform import Transform, TransformList
|
|
|
|
"""
|
|
See "Data Augmentation" tutorial for an overview of the system:
|
|
https://detectron2.readthedocs.io/tutorials/augmentation.html
|
|
"""
|
|
|
|
|
|
__all__ = [
|
|
"Augmentation",
|
|
"AugmentationList",
|
|
"AugInput",
|
|
"TransformGen",
|
|
"apply_transform_gens",
|
|
"StandardAugInput",
|
|
"apply_augmentations",
|
|
]
|
|
|
|
|
|
def _check_img_dtype(img):
|
|
assert isinstance(img, np.ndarray), "[Augmentation] Needs an numpy array, but got a {}!".format(
|
|
type(img)
|
|
)
|
|
assert not isinstance(img.dtype, np.integer) or (
|
|
img.dtype == np.uint8
|
|
), "[Augmentation] Got image of type {}, use uint8 or floating points instead!".format(
|
|
img.dtype
|
|
)
|
|
assert img.ndim in [2, 3], img.ndim
|
|
|
|
|
|
def _get_aug_input_args(aug, aug_input) -> List[Any]:
|
|
"""
|
|
Get the arguments to be passed to ``aug.get_transform`` from the input ``aug_input``.
|
|
"""
|
|
if aug.input_args is None:
|
|
|
|
prms = list(inspect.signature(aug.get_transform).parameters.items())
|
|
|
|
|
|
|
|
if len(prms) == 1:
|
|
names = ("image",)
|
|
else:
|
|
names = []
|
|
for name, prm in prms:
|
|
if prm.kind in (
|
|
inspect.Parameter.VAR_POSITIONAL,
|
|
inspect.Parameter.VAR_KEYWORD,
|
|
):
|
|
raise TypeError(
|
|
f""" \
|
|
The default implementation of `{type(aug)}.__call__` does not allow \
|
|
`{type(aug)}.get_transform` to use variable-length arguments (*args, **kwargs)! \
|
|
If arguments are unknown, reimplement `__call__` instead. \
|
|
"""
|
|
)
|
|
names.append(name)
|
|
aug.input_args = tuple(names)
|
|
|
|
args = []
|
|
for f in aug.input_args:
|
|
try:
|
|
args.append(getattr(aug_input, f))
|
|
except AttributeError as e:
|
|
raise AttributeError(
|
|
f"{type(aug)}.get_transform needs input attribute '{f}', "
|
|
f"but it is not an attribute of {type(aug_input)}!"
|
|
) from e
|
|
return args
|
|
|
|
|
|
class Augmentation:
|
|
"""
|
|
Augmentation defines (often random) policies/strategies to generate :class:`Transform`
|
|
from data. It is often used for pre-processing of input data.
|
|
|
|
A "policy" that generates a :class:`Transform` may, in the most general case,
|
|
need arbitrary information from input data in order to determine what transforms
|
|
to apply. Therefore, each :class:`Augmentation` instance defines the arguments
|
|
needed by its :meth:`get_transform` method. When called with the positional arguments,
|
|
the :meth:`get_transform` method executes the policy.
|
|
|
|
Note that :class:`Augmentation` defines the policies to create a :class:`Transform`,
|
|
but not how to execute the actual transform operations to those data.
|
|
Its :meth:`__call__` method will use :meth:`AugInput.transform` to execute the transform.
|
|
|
|
The returned `Transform` object is meant to describe deterministic transformation, which means
|
|
it can be re-applied on associated data, e.g. the geometry of an image and its segmentation
|
|
masks need to be transformed together.
|
|
(If such re-application is not needed, then determinism is not a crucial requirement.)
|
|
"""
|
|
|
|
input_args: Optional[Tuple[str]] = None
|
|
"""
|
|
Stores the attribute names needed by :meth:`get_transform`, e.g. ``("image", "sem_seg")``.
|
|
By default, it is just a tuple of argument names in :meth:`self.get_transform`, which often only
|
|
contain "image". As long as the argument name convention is followed, there is no need for
|
|
users to touch this attribute.
|
|
"""
|
|
|
|
def _init(self, params=None):
|
|
if params:
|
|
for k, v in params.items():
|
|
if k != "self" and not k.startswith("_"):
|
|
setattr(self, k, v)
|
|
|
|
def get_transform(self, *args) -> Transform:
|
|
"""
|
|
Execute the policy based on input data, and decide what transform to apply to inputs.
|
|
|
|
Args:
|
|
args: Any fixed-length positional arguments. By default, the name of the arguments
|
|
should exist in the :class:`AugInput` to be used.
|
|
|
|
Returns:
|
|
Transform: Returns the deterministic transform to apply to the input.
|
|
|
|
Examples:
|
|
::
|
|
class MyAug:
|
|
# if a policy needs to know both image and semantic segmentation
|
|
def get_transform(image, sem_seg) -> T.Transform:
|
|
pass
|
|
tfm: Transform = MyAug().get_transform(image, sem_seg)
|
|
new_image = tfm.apply_image(image)
|
|
|
|
Notes:
|
|
Users can freely use arbitrary new argument names in custom
|
|
:meth:`get_transform` method, as long as they are available in the
|
|
input data. In detectron2 we use the following convention:
|
|
|
|
* image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
|
|
floating point in range [0, 1] or [0, 255].
|
|
* boxes: (N,4) ndarray of float32. It represents the instance bounding boxes
|
|
of N instances. Each is in XYXY format in unit of absolute coordinates.
|
|
* sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel.
|
|
|
|
We do not specify convention for other types and do not include builtin
|
|
:class:`Augmentation` that uses other types in detectron2.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def __call__(self, aug_input) -> Transform:
|
|
"""
|
|
Augment the given `aug_input` **in-place**, and return the transform that's used.
|
|
|
|
This method will be called to apply the augmentation. In most augmentation, it
|
|
is enough to use the default implementation, which calls :meth:`get_transform`
|
|
using the inputs. But a subclass can overwrite it to have more complicated logic.
|
|
|
|
Args:
|
|
aug_input (AugInput): an object that has attributes needed by this augmentation
|
|
(defined by ``self.get_transform``). Its ``transform`` method will be called
|
|
to in-place transform it.
|
|
|
|
Returns:
|
|
Transform: the transform that is applied on the input.
|
|
"""
|
|
args = _get_aug_input_args(self, aug_input)
|
|
tfm = self.get_transform(*args)
|
|
assert isinstance(tfm, (Transform, TransformList)), (
|
|
f"{type(self)}.get_transform must return an instance of Transform! "
|
|
f"Got {type(tfm)} instead."
|
|
)
|
|
aug_input.transform(tfm)
|
|
return tfm
|
|
|
|
def _rand_range(self, low=1.0, high=None, size=None):
|
|
"""
|
|
Uniform float random number between low and high.
|
|
"""
|
|
if high is None:
|
|
low, high = 0, low
|
|
if size is None:
|
|
size = []
|
|
return np.random.uniform(low, high, size)
|
|
|
|
def __repr__(self):
|
|
"""
|
|
Produce something like:
|
|
"MyAugmentation(field1={self.field1}, field2={self.field2})"
|
|
"""
|
|
try:
|
|
sig = inspect.signature(self.__init__)
|
|
classname = type(self).__name__
|
|
argstr = []
|
|
for name, param in sig.parameters.items():
|
|
assert (
|
|
param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD
|
|
), "The default __repr__ doesn't support *args or **kwargs"
|
|
assert hasattr(self, name), (
|
|
"Attribute {} not found! "
|
|
"Default __repr__ only works if attributes match the constructor.".format(name)
|
|
)
|
|
attr = getattr(self, name)
|
|
default = param.default
|
|
if default is attr:
|
|
continue
|
|
attr_str = pprint.pformat(attr)
|
|
if "\n" in attr_str:
|
|
|
|
attr_str = "..."
|
|
argstr.append("{}={}".format(name, attr_str))
|
|
return "{}({})".format(classname, ", ".join(argstr))
|
|
except AssertionError:
|
|
return super().__repr__()
|
|
|
|
__str__ = __repr__
|
|
|
|
|
|
class _TransformToAug(Augmentation):
|
|
def __init__(self, tfm: Transform):
|
|
self.tfm = tfm
|
|
|
|
def get_transform(self, *args):
|
|
return self.tfm
|
|
|
|
def __repr__(self):
|
|
return repr(self.tfm)
|
|
|
|
__str__ = __repr__
|
|
|
|
|
|
def _transform_to_aug(tfm_or_aug):
|
|
"""
|
|
Wrap Transform into Augmentation.
|
|
Private, used internally to implement augmentations.
|
|
"""
|
|
assert isinstance(tfm_or_aug, (Transform, Augmentation)), tfm_or_aug
|
|
if isinstance(tfm_or_aug, Augmentation):
|
|
return tfm_or_aug
|
|
else:
|
|
return _TransformToAug(tfm_or_aug)
|
|
|
|
|
|
class AugmentationList(Augmentation):
|
|
"""
|
|
Apply a sequence of augmentations.
|
|
|
|
It has ``__call__`` method to apply the augmentations.
|
|
|
|
Note that :meth:`get_transform` method is impossible (will throw error if called)
|
|
for :class:`AugmentationList`, because in order to apply a sequence of augmentations,
|
|
the kth augmentation must be applied first, to provide inputs needed by the (k+1)th
|
|
augmentation.
|
|
"""
|
|
|
|
def __init__(self, augs):
|
|
"""
|
|
Args:
|
|
augs (list[Augmentation or Transform]):
|
|
"""
|
|
super().__init__()
|
|
self.augs = [_transform_to_aug(x) for x in augs]
|
|
|
|
def __call__(self, aug_input) -> TransformList:
|
|
tfms = []
|
|
for x in self.augs:
|
|
tfm = x(aug_input)
|
|
tfms.append(tfm)
|
|
return TransformList(tfms)
|
|
|
|
def __repr__(self):
|
|
msgs = [str(x) for x in self.augs]
|
|
return "AugmentationList[{}]".format(", ".join(msgs))
|
|
|
|
__str__ = __repr__
|
|
|
|
|
|
class AugInput:
|
|
"""
|
|
Input that can be used with :meth:`Augmentation.__call__`.
|
|
This is a standard implementation for the majority of use cases.
|
|
This class provides the standard attributes **"image", "boxes", "sem_seg"**
|
|
defined in :meth:`__init__` and they may be needed by different augmentations.
|
|
Most augmentation policies do not need attributes beyond these three.
|
|
|
|
After applying augmentations to these attributes (using :meth:`AugInput.transform`),
|
|
the returned transforms can then be used to transform other data structures that users have.
|
|
|
|
Examples:
|
|
::
|
|
input = AugInput(image, boxes=boxes)
|
|
tfms = augmentation(input)
|
|
transformed_image = input.image
|
|
transformed_boxes = input.boxes
|
|
transformed_other_data = tfms.apply_other(other_data)
|
|
|
|
An extended project that works with new data types may implement augmentation policies
|
|
that need other inputs. An algorithm may need to transform inputs in a way different
|
|
from the standard approach defined in this class. In those rare situations, users can
|
|
implement a class similar to this class, that satify the following condition:
|
|
|
|
* The input must provide access to these data in the form of attribute access
|
|
(``getattr``). For example, if an :class:`Augmentation` to be applied needs "image"
|
|
and "sem_seg" arguments, its input must have the attribute "image" and "sem_seg".
|
|
* The input must have a ``transform(tfm: Transform) -> None`` method which
|
|
in-place transforms all its attributes.
|
|
"""
|
|
|
|
|
|
def __init__(
|
|
self,
|
|
image: np.ndarray,
|
|
*,
|
|
boxes: Optional[np.ndarray] = None,
|
|
sem_seg: Optional[np.ndarray] = None,
|
|
):
|
|
"""
|
|
Args:
|
|
image (ndarray): (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
|
|
floating point in range [0, 1] or [0, 255]. The meaning of C is up
|
|
to users.
|
|
boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode
|
|
sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element
|
|
is an integer label of pixel.
|
|
"""
|
|
_check_img_dtype(image)
|
|
self.image = image
|
|
self.boxes = boxes
|
|
self.sem_seg = sem_seg
|
|
|
|
def transform(self, tfm: Transform) -> None:
|
|
"""
|
|
In-place transform all attributes of this class.
|
|
|
|
By "in-place", it means after calling this method, accessing an attribute such
|
|
as ``self.image`` will return transformed data.
|
|
"""
|
|
self.image = tfm.apply_image(self.image)
|
|
if self.boxes is not None:
|
|
self.boxes = tfm.apply_box(self.boxes)
|
|
if self.sem_seg is not None:
|
|
self.sem_seg = tfm.apply_segmentation(self.sem_seg)
|
|
|
|
def apply_augmentations(
|
|
self, augmentations: List[Union[Augmentation, Transform]]
|
|
) -> TransformList:
|
|
"""
|
|
Equivalent of ``AugmentationList(augmentations)(self)``
|
|
"""
|
|
return AugmentationList(augmentations)(self)
|
|
|
|
|
|
def apply_augmentations(augmentations: List[Union[Transform, Augmentation]], inputs):
|
|
"""
|
|
Use ``T.AugmentationList(augmentations)(inputs)`` instead.
|
|
"""
|
|
if isinstance(inputs, np.ndarray):
|
|
|
|
image_only = True
|
|
inputs = AugInput(inputs)
|
|
else:
|
|
image_only = False
|
|
tfms = inputs.apply_augmentations(augmentations)
|
|
return inputs.image if image_only else inputs, tfms
|
|
|
|
|
|
apply_transform_gens = apply_augmentations
|
|
"""
|
|
Alias for backward-compatibility.
|
|
"""
|
|
|
|
TransformGen = Augmentation
|
|
"""
|
|
Alias for Augmentation, since it is something that generates :class:`Transform`s
|
|
"""
|
|
|
|
StandardAugInput = AugInput
|
|
"""
|
|
Alias for compatibility. It's not worth the complexity to have two classes.
|
|
"""
|
|
|