|
|
|
import itertools
|
|
import warnings
|
|
from typing import Any, Dict, List, Tuple, Union
|
|
import torch
|
|
|
|
|
|
class Instances:
    """
    This class represents a list of instances in an image.
    It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
    All fields must have the same ``__len__`` which is the number of instances.

    All other (non-field) attributes of this class are considered private:
    they must start with '_' and are not modifiable by a user.

    Some basic usage:

    1. Set/get/check a field:

       .. code-block:: python

          instances.gt_boxes = Boxes(...)
          print(instances.pred_masks)  # a tensor of shape (N, H, W)
          print('gt_masks' in instances)

    2. ``len(instances)`` returns the number of instances
    3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
       and returns a new :class:`Instances`.
       Typically, ``indices`` is an integer vector of indices,
       or a binary mask of length ``num_instances``

       .. code-block:: python

          category_3_detections = instances[instances.pred_classes == 3]
          confident_detections = instances[instances.scores > 0.9]
    """

    def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
        """
        Args:
            image_size (height, width): the spatial size of the image.
            kwargs: fields to add to this `Instances`.
        """
        self._image_size = image_size
        self._fields: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.set(k, v)

    @property
    def image_size(self) -> Tuple[int, int]:
        """
        Returns:
            tuple: height, width
        """
        return self._image_size

    def __setattr__(self, name: str, val: Any) -> None:
        """
        Route attribute assignment: names starting with '_' become regular
        (private) attributes, every other name is stored as a field via `set`.
        """
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            self.set(name, val)

    def __getattr__(self, name: str) -> Any:
        """
        Look up `name` among the fields. Only called when normal attribute
        lookup fails.

        Raises:
            AttributeError: if there is no field called `name`.
        """
        # The explicit "_fields" check prevents infinite recursion: if
        # `_fields` itself is not set on the object yet, evaluating
        # `self._fields` below would re-enter this method forever.
        if name == "_fields" or name not in self._fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self._fields[name]

    def __contains__(self, name: str) -> bool:
        """
        Support ``name in instances`` as documented in the class docstring.

        Without this method Python falls back to `__iter__` for membership
        tests, which raises ``NotImplementedError`` for this class.

        Returns:
            bool: whether the field called `name` exists.
        """
        return name in self._fields

    def set(self, name: str, value: Any) -> None:
        """
        Set the field named `name` to `value`.
        The length of `value` must be the number of instances,
        and must agree with other existing fields in this object.

        Raises:
            AssertionError: if fields already exist and ``len(value)`` differs
                from ``len(self)``.
        """
        # Warnings emitted while computing the length are recorded and dropped.
        with warnings.catch_warnings(record=True):
            data_len = len(value)
        if self._fields:
            assert (
                len(self) == data_len
            ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
        self._fields[name] = value

    def has(self, name: str) -> bool:
        """
        Returns:
            bool: whether the field called `name` exists.
        """
        return name in self._fields

    def remove(self, name: str) -> None:
        """
        Remove the field called `name`.

        Raises:
            KeyError: if there is no field called `name`.
        """
        del self._fields[name]

    def get(self, name: str) -> Any:
        """
        Returns the field called `name`.

        Raises:
            KeyError: if there is no field called `name`.
        """
        return self._fields[name]

    def get_fields(self) -> Dict[str, Any]:
        """
        Returns:
            dict: a dict which maps names (str) to data of the fields

        Modifying the returned dict will modify this instance.
        """
        return self._fields

    def to(self, *args: Any, **kwargs: Any) -> "Instances":
        """
        Returns:
            Instances: a new object in which every field that has a ``to``
            method is replaced by ``field.to(*args, **kwargs)``; fields
            without that method are carried over unchanged.
        """
        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            if hasattr(v, "to"):
                v = v.to(*args, **kwargs)
            ret.set(k, v)
        return ret

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
        """
        Args:
            item: an index-like object and will be used to index all the fields.

        Returns:
            Instances: a new `Instances` where all fields are indexed by `item`.

        Raises:
            IndexError: if `item` is an ``int`` outside ``[-len(self), len(self))``.
            NotImplementedError: if `item` is an ``int`` and there are no fields,
                because ``len(self)`` is undefined in that case.
        """
        # Exact type match on purpose kept as in the original check: bools and
        # other int subclasses are handed to the fields unchanged.
        if type(item) is int:
            if item >= len(self) or item < -len(self):
                raise IndexError("Instances index out of range!")
            # Turn the index into a slice that selects exactly that one
            # element (step == length), so each indexed field still has a
            # length and the result is a one-instance `Instances`.
            item = slice(item, None, len(self))

        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            ret.set(k, v[item])
        return ret

    def __len__(self) -> int:
        """
        Returns:
            int: the length of the first field (all fields share one length).

        Raises:
            NotImplementedError: if there are no fields.
        """
        for v in self._fields.values():
            # Call ``__len__`` directly rather than ``len()``: ``len()``
            # requires a true ``int`` result, whereas a field may return an
            # int-like object (presumably during tracing -- TODO confirm).
            return v.__len__()
        raise NotImplementedError("Empty Instances does not support __len__!")

    def __iter__(self):
        """
        Iteration is explicitly unsupported.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("`Instances` object is not iterable!")

    @staticmethod
    def cat(instance_lists: List["Instances"]) -> "Instances":
        """
        Concatenate several `Instances` field by field.

        The field names are taken from the first element; each field is merged
        with ``torch.cat`` (tensors), list chaining (lists), or the field
        type's own ``cat`` method.

        Args:
            instance_lists (list[Instances])

        Returns:
            Instances: the concatenation. When the list has a single element,
            that same object is returned (not a copy).

        Raises:
            AssertionError: if the list is empty, holds a non-`Instances`
                element, or the image sizes differ.
            KeyError: if a later element lacks a field of the first element.
            ValueError: if a field type cannot be concatenated.
        """
        assert all(isinstance(i, Instances) for i in instance_lists)
        assert len(instance_lists) > 0
        if len(instance_lists) == 1:
            return instance_lists[0]

        image_size = instance_lists[0].image_size
        # The size-agreement check is skipped when image_size is a tensor.
        if not isinstance(image_size, torch.Tensor):
            for i in instance_lists[1:]:
                assert i.image_size == image_size
        ret = Instances(image_size)
        for k in instance_lists[0]._fields:
            values = [i.get(k) for i in instance_lists]
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                values = torch.cat(values, dim=0)
            elif isinstance(v0, list):
                values = list(itertools.chain.from_iterable(values))
            elif hasattr(type(v0), "cat"):
                values = type(v0).cat(values)
            else:
                raise ValueError("Unsupported type {} for concatenation".format(type(v0)))
            ret.set(k, values)
        return ret

    def __str__(self) -> str:
        """
        Returns:
            str: a summary with the instance count, image size and all fields.
            An object without fields is reported with ``num_instances=0``.
        """
        # ``len(self)`` raises when there are no fields; printing an empty
        # container must not crash, so fall back to 0 in that case.
        num_instances = len(self) if self._fields else 0
        fields = ", ".join(f"{k}: {v}" for k, v in self._fields.items())
        return (
            f"{self.__class__.__name__}("
            f"num_instances={num_instances}, "
            f"image_height={self._image_size[0]}, "
            f"image_width={self._image_size[1]}, "
            f"fields=[{fields}])"
        )

    __repr__ = __str__
|
|
|