import itertools
import logging
import numpy as np
from collections import OrderedDict
from collections.abc import Mapping
from typing import Dict, List, Optional, Tuple, Union
import torch
from omegaconf import DictConfig, OmegaConf
from torch import Tensor, nn

from detectron2.layers import ShapeSpec
from detectron2.structures import BitMasks, Boxes, ImageList, Instances
from detectron2.utils.events import get_event_storage

from .backbone import Backbone

logger = logging.getLogger(__name__)


def _to_container(cfg):
    """
    mmdet asserts that its configs are plain dict/list objects,
    so convert omegaconf containers to native dict/list first.
    """
    if isinstance(cfg, DictConfig):
        cfg = OmegaConf.to_container(cfg, resolve=True)
    from mmcv.utils import ConfigDict

    return ConfigDict(cfg)
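

# A minimal sketch (hypothetical values) of how _to_container is used:
#
#   cfg = OmegaConf.create({"type": "ResNet", "depth": 50})
#   mm_cfg = _to_container(cfg)  # mmcv ConfigDict, accepted by mmdet's build_* functions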


class MMDetBackbone(Backbone):
    """
    Wrapper of mmdetection backbones to use in detectron2.

    mmdet backbones produce list/tuple of tensors, while detectron2 backbones
    produce a dict of tensors. This class wraps the given backbone to produce
    output in detectron2's convention, so it can be used in place of detectron2
    backbones.
    """

    def __init__(
        self,
        backbone: Union[nn.Module, Mapping],
        neck: Union[nn.Module, Mapping, None] = None,
        *,
        output_shapes: List[ShapeSpec],
        output_names: Optional[List[str]] = None,
    ):
        """
        Args:
            backbone: either a backbone module or a mmdet config dict that defines a
                backbone. The backbone takes a 4D image tensor and returns a
                sequence of tensors.
            neck: either a neck module or a mmdet config dict that defines a
                neck. The neck takes the outputs of the backbone and returns a
                sequence of tensors. If None, no neck is used.
            output_shapes: shape for every output of the backbone (or neck, if given).
                stride and channels are often needed.
            output_names: names for every output of the backbone (or neck, if given).
                By default, will use "out0", "out1", ...
        """
        super().__init__()
        if isinstance(backbone, Mapping):
            from mmdet.models import build_backbone

            backbone = build_backbone(_to_container(backbone))
        self.backbone = backbone

        if isinstance(neck, Mapping):
            from mmdet.models import build_neck

            neck = build_neck(_to_container(neck))
        self.neck = neck

        # mmdet modules do not initialize weights in __init__;
        # init_weights() has to be called explicitly.
        logger.info("Initializing mmdet backbone weights...")
        self.backbone.init_weights()
        # train() is overridden in mmdet backbones (e.g. to handle frozen
        # stages and norm layers), so it has to be called explicitly as well.
        self.backbone.train()
        if self.neck is not None:
            logger.info("Initializing mmdet neck weights ...")
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
            self.neck.train()

        self._output_shapes = output_shapes
        if not output_names:
            output_names = [f"out{i}" for i in range(len(output_shapes))]
        self._output_names = output_names

    def forward(self, x) -> Dict[str, Tensor]:
        outs = self.backbone(x)
        if self.neck is not None:
            outs = self.neck(outs)
        assert isinstance(
            outs, (list, tuple)
        ), "mmdet backbone should return a list/tuple of tensors!"
        if len(outs) != len(self._output_shapes):
            raise ValueError(
                "Length of output_shapes does not match outputs from the mmdet backbone: "
                f"{len(outs)} != {len(self._output_shapes)}"
            )
        return {k: v for k, v in zip(self._output_names, outs)}

    def output_shape(self) -> Dict[str, ShapeSpec]:
        return {k: v for k, v in zip(self._output_names, self._output_shapes)}
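

# A usage sketch (hypothetical config values; requires mmdet/mmcv). The mmdet
# backbone/neck are given as config dicts, and output_shapes must describe
# each returned feature map:
#
#   backbone = MMDetBackbone(
#       backbone=dict(type="ResNet", depth=50, num_stages=4,
#                     out_indices=(0, 1, 2, 3)),
#       neck=dict(type="FPN", in_channels=[256, 512, 1024, 2048],
#                 out_channels=256, num_outs=5),
#       output_shapes=[ShapeSpec(channels=256, stride=s)
#                      for s in (4, 8, 16, 32, 64)],
#       output_names=["p2", "p3", "p4", "p5", "p6"],
#   )
#   feats = backbone(torch.rand(2, 3, 224, 224))  # dict: name -> Tensor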


class MMDetDetector(nn.Module):
    """
    Wrapper of a mmdetection detector model, for detection and instance segmentation.
    Input/output formats of this class follow detectron2's convention, so a
    mmdetection model can be trained and evaluated in detectron2.
    """

    def __init__(
        self,
        detector: Union[nn.Module, Mapping],
        *,
        # 32 is the padding used by mmdet's default dataset configs
        size_divisibility=32,
        pixel_mean: Tuple[float, ...],
        pixel_std: Tuple[float, ...],
    ):
        """
        Args:
            detector: a mmdet detector, or a mmdet config dict that defines a detector.
            size_divisibility: pad input images to a multiple of this number
            pixel_mean: per-channel mean to normalize input images
            pixel_std: per-channel stddev to normalize input images
        """
        super().__init__()
        if isinstance(detector, Mapping):
            from mmdet.models import build_detector

            detector = build_detector(_to_container(detector))
        self.detector = detector
        self.detector.init_weights()
        self.size_divisibility = size_divisibility

        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, size_divisibility=self.size_divisibility).tensor
        metas = []
        # whether outputs should be rescaled to the original image resolution,
        # indicated by the presence of "height"/"width" in the inputs
        rescale = {"height" in x for x in batched_inputs}
        if len(rescale) != 1:
            raise ValueError("Some inputs have original height/width, but some don't!")
        rescale = list(rescale)[0]
        output_shapes = []
        for input in batched_inputs:
            meta = {}
            c, h, w = input["image"].shape
            meta["img_shape"] = meta["ori_shape"] = (h, w, c)
            if rescale:
                scale_factor = np.array(
                    [w / input["width"], h / input["height"]] * 2, dtype="float32"
                )
                ori_shape = (input["height"], input["width"])
                output_shapes.append(ori_shape)
                meta["ori_shape"] = ori_shape + (c,)
            else:
                scale_factor = 1.0
                output_shapes.append((h, w))
            meta["scale_factor"] = scale_factor
            meta["flip"] = False
            padh, padw = images.shape[-2:]
            meta["pad_shape"] = (padh, padw, c)
            metas.append(meta)

        if self.training:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            if gt_instances[0].has("gt_masks"):
                from mmdet.core import PolygonMasks as mm_PolygonMasks, BitmapMasks as mm_BitMasks

                def convert_mask(m, shape):
                    # convert detectron2 masks to mmdet's mask formats
                    if isinstance(m, BitMasks):
                        return mm_BitMasks(m.tensor.cpu().numpy(), shape[0], shape[1])
                    else:
                        return mm_PolygonMasks(m.polygons, shape[0], shape[1])

                gt_masks = [convert_mask(x.gt_masks, x.image_size) for x in gt_instances]
                losses_and_metrics = self.detector.forward_train(
                    images,
                    metas,
                    [x.gt_boxes.tensor for x in gt_instances],
                    [x.gt_classes for x in gt_instances],
                    gt_masks=gt_masks,
                )
            else:
                losses_and_metrics = self.detector.forward_train(
                    images,
                    metas,
                    [x.gt_boxes.tensor for x in gt_instances],
                    [x.gt_classes for x in gt_instances],
                )
            return _parse_losses(losses_and_metrics)
        else:
            results = self.detector.simple_test(images, metas, rescale=rescale)
            results = [
                {"instances": _convert_mmdet_result(r, shape)}
                for r, shape in zip(results, output_shapes)
            ]
            return results

    @property
    def device(self):
        return self.pixel_mean.device
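

# A sketch of the expected input format (detectron2's standard dataset dict;
# the values here are hypothetical):
#
#   batched_inputs = [{
#       "image": torch.as_tensor(img),  # CHW tensor from the dataloader
#       "height": 480, "width": 640,    # desired output resolution (inference)
#       "instances": gt_instances,      # ground-truth Instances (training)
#   }]
#   losses = detector(batched_inputs)   # training: dict of scalar loss tensors
#   preds = detector(batched_inputs)    # eval: [{"instances": Instances}, ...]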


def _convert_mmdet_result(result, shape: Tuple[int, int]) -> Instances:
    if isinstance(result, tuple):
        bbox_result, segm_result = result
        if isinstance(segm_result, tuple):
            segm_result = segm_result[0]
    else:
        bbox_result, segm_result = result, None

    bboxes = torch.from_numpy(np.vstack(bbox_result))
    bboxes, scores = bboxes[:, :4], bboxes[:, -1]
    labels = [
        torch.full((bbox.shape[0],), i, dtype=torch.int32) for i, bbox in enumerate(bbox_result)
    ]
    labels = torch.cat(labels)
    inst = Instances(shape)
    inst.pred_boxes = Boxes(bboxes)
    inst.scores = scores
    inst.pred_classes = labels

    if segm_result is not None and len(labels) > 0:
        segm_result = list(itertools.chain(*segm_result))
        segm_result = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in segm_result]
        segm_result = torch.stack(segm_result, dim=0)
        inst.pred_masks = segm_result
    return inst
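

# For reference, a sketch of the per-image result this converts from (mmdet's
# convention; the numbers are made up): bbox_result is a list with one (N, 5)
# float array per class, each row being (x1, y1, x2, y2, score), and
# segm_result, when present, is a per-class list of binary masks:
#
#   bbox_result = [np.zeros((0, 5)),                   # class 0: no detections
#                  np.array([[10, 10, 50, 60, 0.9]])]  # class 1: one detection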


def _parse_losses(losses: Dict[str, Tensor]) -> Dict[str, Tensor]:
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(f"{loss_name} is not a tensor or list of tensors")

        if "loss" not in loss_name:
            # this is a metric, not a loss: put it to the event storage
            # instead of returning it to the optimizer
            storage = get_event_storage()
            value = log_vars.pop(loss_name).cpu().item()
            storage.put_scalar(loss_name, value)
    return log_vars
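

# Example (hypothetical values): {"loss_cls": t1, "loss_bbox": t2, "acc": t3}
# returns {"loss_cls": ..., "loss_bbox": ...}, while "acc" is logged to the
# event storage because its key does not contain "loss".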