Spaces:
Runtime error
Runtime error
Upload encoders/_base.py
Browse files- encoders/_base.py +53 -0
encoders/_base.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from typing import List
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
from . import _utils as utils
|
7 |
+
|
8 |
+
|
9 |
+
class EncoderMixin:
|
10 |
+
"""Add encoder functionality such as:
|
11 |
+
- output channels specification of feature tensors (produced by encoder)
|
12 |
+
- patching first convolution for arbitrary input channels
|
13 |
+
"""
|
14 |
+
|
15 |
+
@property
|
16 |
+
def out_channels(self):
|
17 |
+
"""Return channels dimensions for each tensor of forward output of encoder"""
|
18 |
+
return self._out_channels[: self._depth + 1]
|
19 |
+
|
20 |
+
def set_in_channels(self, in_channels, pretrained=True):
|
21 |
+
"""Change first convolution channels"""
|
22 |
+
if in_channels == 3:
|
23 |
+
return
|
24 |
+
|
25 |
+
self._in_channels = in_channels
|
26 |
+
if self._out_channels[0] == 3:
|
27 |
+
self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
|
28 |
+
|
29 |
+
utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)
|
30 |
+
|
31 |
+
def get_stages(self):
|
32 |
+
"""Method should be overridden in encoder"""
|
33 |
+
raise NotImplementedError
|
34 |
+
|
35 |
+
def make_dilated(self, output_stride):
|
36 |
+
|
37 |
+
if output_stride == 16:
|
38 |
+
stage_list=[5,]
|
39 |
+
dilation_list=[2,]
|
40 |
+
|
41 |
+
elif output_stride == 8:
|
42 |
+
stage_list=[4, 5]
|
43 |
+
dilation_list=[2, 4]
|
44 |
+
|
45 |
+
else:
|
46 |
+
raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))
|
47 |
+
|
48 |
+
stages = self.get_stages()
|
49 |
+
for stage_indx, dilation_rate in zip(stage_list, dilation_list):
|
50 |
+
utils.replace_strides_with_dilation(
|
51 |
+
module=stages[stage_indx],
|
52 |
+
dilation_rate=dilation_rate,
|
53 |
+
)
|