josedolot commited on
Commit
5e85cf4
·
1 Parent(s): 5b4e116

Upload encoders/_base.py

Browse files
Files changed (1) hide show
  1. encoders/_base.py +53 -0
encoders/_base.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List
4
+ from collections import OrderedDict
5
+
6
+ from . import _utils as utils
7
+
8
+
9
+ class EncoderMixin:
10
+ """Add encoder functionality such as:
11
+ - output channels specification of feature tensors (produced by encoder)
12
+ - patching first convolution for arbitrary input channels
13
+ """
14
+
15
+ @property
16
+ def out_channels(self):
17
+ """Return channels dimensions for each tensor of forward output of encoder"""
18
+ return self._out_channels[: self._depth + 1]
19
+
20
+ def set_in_channels(self, in_channels, pretrained=True):
21
+ """Change first convolution channels"""
22
+ if in_channels == 3:
23
+ return
24
+
25
+ self._in_channels = in_channels
26
+ if self._out_channels[0] == 3:
27
+ self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
28
+
29
+ utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)
30
+
31
+ def get_stages(self):
32
+ """Method should be overridden in encoder"""
33
+ raise NotImplementedError
34
+
35
+ def make_dilated(self, output_stride):
36
+
37
+ if output_stride == 16:
38
+ stage_list=[5,]
39
+ dilation_list=[2,]
40
+
41
+ elif output_stride == 8:
42
+ stage_list=[4, 5]
43
+ dilation_list=[2, 4]
44
+
45
+ else:
46
+ raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))
47
+
48
+ stages = self.get_stages()
49
+ for stage_indx, dilation_rate in zip(stage_list, dilation_list):
50
+ utils.replace_strides_with_dilation(
51
+ module=stages[stage_indx],
52
+ dilation_rate=dilation_rate,
53
+ )