# Copyright (c) OpenCD. All rights reserved.
import torch
import torch.nn as nn
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize

from opencd.registry import MODELS


@MODELS.register_module()
class IdentityHead(BaseDecodeHead):
    """Identity Head."""

    def __init__(self, **kwargs):
        super().__init__(channels=1, **kwargs)
        # The head returns features unchanged, so the default `conv_seg`
        # classifier created by `BaseDecodeHead` is not needed.
        delattr(self, 'conv_seg')

    def init_weights(self):
        pass

    def _forward_feature(self, inputs):
        """Select and transform the input features; no further processing
        is applied.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        return x

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        return output


@MODELS.register_module()
class DSIdentityHead(BaseDecodeHead):
    """Deep Supervision Identity Head."""

    def __init__(self, **kwargs):
        super().__init__(channels=1, **kwargs)
        # As in `IdentityHead`, the default `conv_seg` classifier is unused.
        delattr(self, 'conv_seg')

    def init_weights(self):
        pass

    def _forward_feature(self, inputs):
        """Select and transform the input features; no further processing
        is applied.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        return x

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        return output

    def loss_by_feat(self, seg_logits, batch_data_samples):
        """Compute the deep-supervision segmentation loss.

        Args:
            seg_logits (list[Tensor] | tuple[Tensor]): Multi-level logit maps
                from the decode head; each entry is supervised independently
                and its losses are prefixed with ``aux_{idx}_``.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        seg_label_size = seg_label.shape[2:]
        # Labels are stacked as (N, 1, H, W); keep the channel dim for the
        # pixel sampler and drop it for the loss and accuracy computation.
        squeezed_seg_label = seg_label.squeeze(1)

        for seg_idx, single_seg_logit in enumerate(seg_logits):
            # Upsample every supervised logit map to the label resolution.
            single_seg_logit = resize(
                input=single_seg_logit,
                size=seg_label_size,
                mode='bilinear',
                align_corners=self.align_corners)
            if self.sampler is not None:
                seg_weight = self.sampler.sample(single_seg_logit, seg_label)
            else:
                seg_weight = None

            if not isinstance(self.loss_decode, nn.ModuleList):
                losses_decode = [self.loss_decode]
            else:
                losses_decode = self.loss_decode
            for loss_decode in losses_decode:
                loss_name = f'aux_{seg_idx}_' + loss_decode.loss_name
                if loss_name not in loss:
                    loss[loss_name] = loss_decode(
                        single_seg_logit,
                        squeezed_seg_label,
                        weight=seg_weight,
                        ignore_index=self.ignore_index)
                else:
                    loss[loss_name] += loss_decode(
                        single_seg_logit,
                        squeezed_seg_label,
                        weight=seg_weight,
                        ignore_index=self.ignore_index)

            # Only the accuracy of the last supervised logit map is kept.
            loss['acc_seg'] = accuracy(
                single_seg_logit,
                squeezed_seg_label,
                ignore_index=self.ignore_index)
        return loss
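

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module: it assumes
    # the default `CrossEntropyLoss` config inherited from `BaseDecodeHead`
    # and purely illustrative shapes/values; it is not an official OpenCD test.
    from mmengine.structures import PixelData
    from mmseg.structures import SegDataSample

    head = DSIdentityHead(in_channels=2, in_index=0, num_classes=2)

    # Two deep-supervision logit maps at different resolutions.
    seg_logits = [torch.randn(2, 2, 32, 32), torch.randn(2, 2, 16, 16)]

    # Build ground-truth samples carrying (1, H, W) label maps.
    batch_data_samples = []
    for _ in range(2):
        sample = SegDataSample()
        sample.gt_sem_seg = PixelData(data=torch.randint(0, 2, (1, 64, 64)))
        batch_data_samples.append(sample)

    losses = head.loss_by_feat(seg_logits, batch_data_samples)
    # Expected keys: 'aux_0_loss_ce', 'aux_1_loss_ce' and 'acc_seg'.
    print({k: float(v) for k, v in losses.items()})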