josedolot commited on
Commit
1590106
·
1 Parent(s): 4270e7b

Upload encoders/_utils.py

Browse files
Files changed (1) hide show
  1. encoders/_utils.py +59 -0
encoders/_utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
6
+ """Change first convolution layer input channels.
7
+ In case:
8
+ in_channels == 1 or in_channels == 2 -> reuse original weights
9
+ in_channels > 3 -> make random kaiming normal initialization
10
+ """
11
+
12
+ # get first conv
13
+ for module in model.modules():
14
+ if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
15
+ break
16
+
17
+ weight = module.weight.detach()
18
+ module.in_channels = new_in_channels
19
+
20
+ if not pretrained:
21
+ module.weight = nn.parameter.Parameter(
22
+ torch.Tensor(
23
+ module.out_channels,
24
+ new_in_channels // module.groups,
25
+ *module.kernel_size
26
+ )
27
+ )
28
+ module.reset_parameters()
29
+
30
+ elif new_in_channels == 1:
31
+ new_weight = weight.sum(1, keepdim=True)
32
+ module.weight = nn.parameter.Parameter(new_weight)
33
+
34
+ else:
35
+ new_weight = torch.Tensor(
36
+ module.out_channels,
37
+ new_in_channels // module.groups,
38
+ *module.kernel_size
39
+ )
40
+
41
+ for i in range(new_in_channels):
42
+ new_weight[:, i] = weight[:, i % default_in_channels]
43
+
44
+ new_weight = new_weight * (default_in_channels / new_in_channels)
45
+ module.weight = nn.parameter.Parameter(new_weight)
46
+
47
+
48
+ def replace_strides_with_dilation(module, dilation_rate):
49
+ """Patch Conv2d modules replacing strides with dilation"""
50
+ for mod in module.modules():
51
+ if isinstance(mod, nn.Conv2d):
52
+ mod.stride = (1, 1)
53
+ mod.dilation = (dilation_rate, dilation_rate)
54
+ kh, kw = mod.kernel_size
55
+ mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate)
56
+
57
+ # Kostyl for EfficientNet
58
+ if hasattr(mod, "static_padding"):
59
+ mod.static_padding = nn.Identity()