# Spaces:
# Runtime error
# Runtime error
import torch | |
import torch.nn as nn | |
from typing import List | |
from collections import OrderedDict | |
from . import _utils as utils | |
class EncoderMixin:
    """Mixin that bolts shared encoder helpers onto a model class.

    It contributes:
        - ``out_channels``: channel counts of the feature tensors produced by
          the encoder, truncated to ``self._depth + 1`` entries;
        - ``set_in_channels``: re-targets the first convolution so the encoder
          accepts an arbitrary number of input channels;
        - ``make_dilated``: applies stride-to-dilation replacement to the last
          stage(s) for an output stride of 16 or 8.

    The host class is expected to provide ``self._out_channels`` and
    ``self._depth`` and to override ``get_stages``.
    """

    def out_channels(self):
        """Return the first ``self._depth + 1`` entries of ``self._out_channels``.

        NOTE(review): defined as a plain method here (no ``@property``), so it
        has to be called as ``encoder.out_channels()``; confirm call sites agree.
        """
        keep = self._depth + 1
        return self._out_channels[:keep]

    def set_in_channels(self, in_channels, pretrained=True):
        """Reconfigure the encoder for ``in_channels`` input channels.

        Asking for 3 channels does nothing. For any other count this:
          1. stores the count in ``self._in_channels``;
          2. if ``self._out_channels`` starts with 3, swaps that leading entry
             for ``in_channels`` (the result is stored as a tuple);
          3. hands this module to ``utils.patch_first_conv``, forwarding
             ``pretrained`` unchanged.
        """
        if in_channels == 3:
            # Nothing to adapt; the attributes are left untouched.
            return

        self._in_channels = in_channels
        if self._out_channels[0] == 3:
            # Keep every channel entry except the first, which becomes the new input width.
            self._out_channels = (in_channels, *list(self._out_channels)[1:])

        utils.patch_first_conv(
            model=self,
            new_in_channels=in_channels,
            pretrained=pretrained,
        )

    def get_stages(self):
        """Return the encoder stages, indexable by integer stage number.

        The mixin has no stages of its own; concrete encoders must override this.
        """
        raise NotImplementedError

    def make_dilated(self, output_stride):
        """Apply ``utils.replace_strides_with_dilation`` to the late stages.

        ``output_stride == 16`` targets stage 5 with dilation rate 2;
        ``output_stride == 8`` targets stage 4 (rate 2) and stage 5 (rate 4).
        Any other value raises ``ValueError`` before ``get_stages`` is called.
        """
        if output_stride == 16:
            plan = [(5, 2)]
        elif output_stride == 8:
            plan = [(4, 2), (5, 4)]
        else:
            raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))

        stages = self.get_stages()
        for stage_index, rate in plan:
            utils.replace_strides_with_dilation(
                module=stages[stage_index],
                dilation_rate=rate,
            )