Spaces:
Runtime error
Runtime error
File size: 14,867 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 |
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import print_log
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
OptSampleList, SampleList, add_prefix)
from .base import BaseSegmentor
@MODELS.register_module()
class EncoderDecoder(BaseSegmentor):
    """Encoder Decoder segmentors.

    EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
    Note that auxiliary_head is only used for deep supervision during training,
    which could be dumped during inference.

    1. The ``loss`` method is used to calculate the loss of model,
    which includes two steps: (1) Extracts features to obtain the feature maps
    (2) Call the decode head loss function to forward decode head model and
    calculate losses.

    .. code:: text

     loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional)
     _decode_head_forward_train(): decode_head.loss()
     _auxiliary_head_forward_train(): auxiliary_head.loss (optional)

    2. The ``predict`` method is used to predict segmentation results,
    which includes two steps: (1) Run inference function to obtain the list of
    seg_logits (2) Call post-processing function to obtain list of
    ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.

    .. code:: text

     predict(): inference() -> postprocess_result()
     inference(): whole_inference()/slide_inference()
     whole_inference()/slide_inference(): encode_decode()
     encode_decode(): extract_feat() -> decode_head.predict()

    3. The ``_forward`` method is used to output the tensor by running the model,
    which includes two steps: (1) Extracts features to obtain the feature maps
    (2) Call the decode head forward function to forward decode head model.

    .. code:: text

     _forward(): extract_feat() -> decode_head.forward()

    Args:
        backbone (ConfigType): The config for the backbone of segmentor.
        decode_head (ConfigType): The config for the decode head of segmentor.
        neck (OptConfigType): The config for the neck of segmentor.
            Defaults to None.
        auxiliary_head (OptConfigType): The config for the auxiliary head of
            segmentor. Defaults to None.
        train_cfg (OptConfigType): The config for training. Defaults to None.
        test_cfg (OptConfigType): The config for testing. Defaults to None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
        pretrained (str, optional): The path for pretrained model.
            Defaults to None.
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`.
    """  # noqa: E501

    def __init__(self,
                 backbone: ConfigType,
                 decode_head: ConfigType,
                 neck: OptConfigType = None,
                 auxiliary_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 pretrained: Optional[str] = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        if pretrained is not None:
            assert backbone.get('pretrained') is None, \
                'both backbone and segmentor set pretrained weight'
            # Item assignment works for both plain dicts and ConfigDict;
            # attribute assignment only works for ConfigDict.
            backbone['pretrained'] = pretrained
        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)
        self._init_decode_head(decode_head)
        self._init_auxiliary_head(auxiliary_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        assert self.with_decode_head

    def _init_decode_head(self, decode_head: ConfigType) -> None:
        """Build ``decode_head`` and mirror its output attributes on the
        segmentor (``align_corners``, ``num_classes``, ``out_channels``)."""
        self.decode_head = MODELS.build(decode_head)
        self.align_corners = self.decode_head.align_corners
        self.num_classes = self.decode_head.num_classes
        self.out_channels = self.decode_head.out_channels

    def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None:
        """Build ``auxiliary_head`` if configured.

        A list of configs becomes an ``nn.ModuleList`` of heads; a single
        config becomes a single head. ``None`` leaves the attribute unset.
        """
        if auxiliary_head is not None:
            if isinstance(auxiliary_head, list):
                self.auxiliary_head = nn.ModuleList()
                for head_cfg in auxiliary_head:
                    self.auxiliary_head.append(MODELS.build(head_cfg))
            else:
                self.auxiliary_head = MODELS.build(auxiliary_head)

    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Extract features from images (backbone, then neck if present)."""
        x = self.backbone(inputs)
        if self.with_neck:
            x = self.neck(x)
        return x

    def encode_decode(self, inputs: Tensor,
                      batch_img_metas: List[dict]) -> Tensor:
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input."""
        x = self.extract_feat(inputs)
        seg_logits = self.decode_head.predict(x, batch_img_metas,
                                              self.test_cfg)

        return seg_logits

    def _decode_head_forward_train(self, inputs: List[Tensor],
                                   data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for decode head in
        training. Loss keys are prefixed with ``decode``."""
        losses = dict()
        loss_decode = self.decode_head.loss(inputs, data_samples,
                                            self.train_cfg)

        losses.update(add_prefix(loss_decode, 'decode'))
        return losses

    def _auxiliary_head_forward_train(self, inputs: List[Tensor],
                                      data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for auxiliary head in
        training.

        Loss keys are prefixed with ``aux_{idx}`` for a ModuleList of heads,
        or ``aux`` for a single head.
        """
        losses = dict()
        if isinstance(self.auxiliary_head, nn.ModuleList):
            for idx, aux_head in enumerate(self.auxiliary_head):
                loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg)
                losses.update(add_prefix(loss_aux, f'aux_{idx}'))
        else:
            loss_aux = self.auxiliary_head.loss(inputs, data_samples,
                                                self.train_cfg)
            losses.update(add_prefix(loss_aux, 'aux'))

        return losses

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Input images.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        x = self.extract_feat(inputs)

        losses = dict()

        loss_decode = self._decode_head_forward_train(x, data_samples)
        losses.update(loss_decode)

        if self.with_auxiliary_head:
            loss_aux = self._auxiliary_head_forward_train(x, data_samples)
            losses.update(loss_aux)

        return losses

    def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`], optional): The seg data
                samples. It usually includes information such as `metainfo`
                and `gt_sem_seg`.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contains:

            - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
            - ``seg_logits``(PixelData): Predicted logits of semantic
                segmentation before normalization.
        """
        if data_samples is not None:
            batch_img_metas = [
                data_sample.metainfo for data_sample in data_samples
            ]
        else:
            # Build one independent dict per image. ``[dict(...)] * N`` would
            # alias a single dict, so in-place edits (e.g. the ``img_shape``
            # update in ``slide_inference``) would leak into every entry.
            batch_img_metas = [
                dict(
                    ori_shape=inputs.shape[2:],
                    img_shape=inputs.shape[2:],
                    pad_shape=inputs.shape[2:],
                    padding_size=[0, 0, 0, 0])
                for _ in range(inputs.shape[0])
            ]

        seg_logits = self.inference(inputs, batch_img_metas)

        return self.postprocess_result(seg_logits, data_samples)

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            Tensor: Forward output of model without any post-processes.
        """
        x = self.extract_feat(inputs)
        return self.decode_head.forward(x)

    def slide_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.

        Args:
            inputs (tensor): the tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
                NOTE: ``batch_img_metas[0]['img_shape']`` is overwritten in
                place with each crop's spatial shape.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        # Item access works for both plain dicts and ConfigDict.
        h_stride, w_stride = self.test_cfg['stride']
        h_crop, w_crop = self.test_cfg['crop_size']
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = self.out_channels
        # Ceil-divide the remaining extent by the stride so the last window
        # reaches the bottom/right border; at least one window per axis.
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift a window clipped by the border back so it keeps the
                # full crop size (never before index 0).
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                # change the image shape to patch shape; presumably
                # decode_head.predict reads this to size its output — confirm
                batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
                # the output of encode_decode is seg logits tensor map
                # with shape [N, C, H, W]
                crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
                # Zero-pad the crop logits back to full-image size and
                # accumulate; overlaps are averaged via count_mat below.
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        # Every pixel must be covered by at least one window.
        assert (count_mat == 0).sum() == 0
        seg_logits = preds / count_mat

        return seg_logits

    def whole_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference with full image.

        Args:
            inputs (Tensor): The tensor should have a shape NxCxHxW, which
                contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        seg_logits = self.encode_decode(inputs, batch_img_metas)

        return seg_logits

    def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
        """Inference with slide/whole style.

        The mode is read from ``test_cfg['mode']`` and defaults to
        ``'whole'`` when the key (or ``test_cfg`` itself) is absent.

        Args:
            inputs (Tensor): The input image of shape (N, 3, H, W).
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', 'pad_shape', and 'padding_size'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """
        # Resolve the mode once so the validation and the dispatch agree.
        # Previously the assert defaulted to 'whole' but the dispatch read
        # ``test_cfg.mode`` directly and crashed when the key was missing.
        test_cfg = self.test_cfg if self.test_cfg is not None else {}
        mode = test_cfg.get('mode', 'whole')
        assert mode in ['slide', 'whole'], \
            f'Only "slide" or "whole" test mode are supported, but got ' \
            f'{mode}.'
        ori_shape = batch_img_metas[0]['ori_shape']
        if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas):
            print_log(
                'Image shapes are different in the batch.',
                logger='current',
                level=logging.WARN)
        if mode == 'slide':
            seg_logit = self.slide_inference(inputs, batch_img_metas)
        else:
            seg_logit = self.whole_inference(inputs, batch_img_metas)

        return seg_logit

    def aug_test(self, inputs, batch_img_metas, rescale=True):
        """Test with augmentations by averaging logits over augmented views.

        Only rescale=True is supported.

        Args:
            inputs (list[Tensor]): One input batch per augmentation.
            batch_img_metas (list[list[dict]]): Matching metainfo per
                augmentation.
            rescale (bool): Must be True.

        Returns:
            list[Tensor]: Per-image predicted label maps (argmax over the
                class dimension of the averaged logits).
        """
        # aug_test rescale all imgs back to ori_shape for now
        assert rescale
        # ``inference`` accepts only (inputs, batch_img_metas); passing
        # ``rescale`` as a third argument raised TypeError.
        # NOTE(review): summing assumes every augmented view yields logits
        # of the same shape — confirm for the configured TTA pipeline.
        # to save memory, we get augmented seg logit inplace
        seg_logit = self.inference(inputs[0], batch_img_metas[0])
        for aug_inputs, aug_metas in zip(inputs[1:], batch_img_metas[1:]):
            cur_seg_logit = self.inference(aug_inputs, aug_metas)
            seg_logit += cur_seg_logit
        seg_logit /= len(inputs)
        seg_pred = seg_logit.argmax(dim=1)
        # unravel batch dim
        seg_pred = list(seg_pred)
        return seg_pred
|