Spaces:
Build error
Build error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import logging | |
| from abc import ABCMeta, abstractmethod | |
| import torch.nn as nn | |
| # from .utils import load_checkpoint | |
| from mmcv_custom.checkpoint import load_checkpoint | |
class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Base backbone.

    This class defines the basic functions of a backbone. Any backbone that
    inherits this class should at least define its own `forward` function.
    `forward` is declared abstract, so a subclass that does not override it
    cannot be instantiated.
    """

    def init_weights(self, pretrained=None, patch_padding='pad',
                     part_features=None):
        """Init backbone weights.

        Args:
            pretrained (str | None): If pretrained is a string, then it
                initializes backbone weights by loading the pretrained
                checkpoint. If pretrained is None, then it follows default
                initializer or customized initializer in subclasses.
            patch_padding: Passed through unchanged to ``load_checkpoint``
                and only used when ``pretrained`` is a string. Defaults to
                ``'pad'``. NOTE(review): the accepted values are defined by
                ``mmcv_custom.checkpoint.load_checkpoint`` - confirm there.
            part_features: Passed through unchanged to ``load_checkpoint``
                and only used when ``pretrained`` is a string. Defaults to
                None. NOTE(review): semantics are defined by
                ``load_checkpoint`` - confirm there.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            # strict=False is requested so the checkpoint is loaded
            # non-strictly; the extra keyword arguments are forwarded as-is.
            load_checkpoint(
                self,
                pretrained,
                strict=False,
                logger=logger,
                patch_padding=patch_padding,
                part_features=part_features)
        elif pretrained is None:
            # use default initializer or customized initializer in subclasses
            pass
        else:
            raise TypeError('pretrained must be a str or None.'
                            f' But received {type(pretrained)}.')

    @abstractmethod
    def forward(self, x):
        """Forward function.

        Subclasses must override this method; the base implementation has no
        body and returns None.

        Args:
            x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of
                torch.Tensor, containing input data for forward computation.
        """