""" Conv2d + BN + Act

Hacked together by / Copyright 2020 Ross Wightman
"""
import functools
from torch import nn as nn

from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer


class ConvNormAct(nn.Module):
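    """Conv2d followed by a fused normalization + activation layer.

    The norm layer attribute is named ``bn`` to remain weight-compatible with
    older checkpoints that used separate norm and act layer definitions.

    Example (illustrative values only, default padding assumed)::

        block = ConvNormAct(32, 64, kernel_size=3, stride=2, act_layer=nn.SiLU)
        # input (N, 32, H, W) -> output (N, 64, H/2, W/2)
    """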
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            drop_layer=None,
    ):
        super().__init__()
        norm_kwargs = norm_kwargs or {}
        act_kwargs = act_kwargs or {}

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        if drop_layer:
            norm_kwargs['drop_layer'] = drop_layer
        self.bn = norm_act_layer(
            out_channels,
            apply_act=apply_act,
            act_kwargs=act_kwargs,
            **norm_kwargs,
        )

    @property
    def in_channels(self):
        return self.conv.in_channels

    @property
    def out_channels(self):
        return self.conv.out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


# Alias kept for backwards compatibility with code importing the old ConvBnAct name
ConvBnAct = ConvNormAct


def create_aa(aa_layer, channels, stride=2, enable=True):
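    """Create an anti-aliasing (pooling / blur) layer for stride-2 downsampling.

    Returns ``nn.Identity()`` when no ``aa_layer`` is given or ``enable`` is False.
    Otherwise dispatches on the differing constructor signatures: a
    ``functools.partial`` over an ``AvgPool2d`` subclass is called with no extra
    args, other partials receive ``channels``, a bare ``AvgPool2d`` subclass
    receives ``stride`` as its kernel size, and any other layer receives
    ``channels`` and ``stride`` keyword args.
    """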
    if not aa_layer or not enable:
        return nn.Identity()
    if isinstance(aa_layer, functools.partial):
        if issubclass(aa_layer.func, nn.AvgPool2d):
            return aa_layer()
        else:
            return aa_layer(channels)
    elif issubclass(aa_layer, nn.AvgPool2d):
        return aa_layer(stride)
    else:
        return aa_layer(channels=channels, stride=stride)


class ConvNormActAa(nn.Module):
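    """Conv2d + norm/act with an optional anti-aliased (blur pool style) downsample.

    When ``aa_layer`` is set and ``stride == 2``, the convolution runs at stride 1
    and the anti-aliasing layer performs the spatial downsampling after norm/act.
    As in ``ConvNormAct``, the norm layer attribute is named ``bn`` for weight
    compatibility with older checkpoints.
    """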
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            aa_layer=None,
            drop_layer=None,
    ):
        super().__init__()
        # anti-aliasing only applies when a layer is provided and the block downsamples (stride 2)
        use_aa = aa_layer is not None and stride == 2
        norm_kwargs = norm_kwargs or {}
        act_kwargs = act_kwargs or {}

        # with anti-aliasing, the conv runs at stride 1 and self.aa performs the downsample
        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        if drop_layer:
            norm_kwargs['drop_layer'] = drop_layer
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)

    @property
    def in_channels(self):
        return self.conv.in_channels

    @property
    def out_channels(self):
        return self.conv.out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.aa(x)
        return x
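

# Illustrative usage sketch (values arbitrary, assumes ``import torch`` and default padding):
#
#   aa = functools.partial(nn.AvgPool2d, kernel_size=2, stride=2)
#   block = ConvNormActAa(32, 64, kernel_size=3, stride=2, act_layer=nn.SiLU, aa_layer=aa)
#   y = block(torch.randn(1, 32, 56, 56))  # -> (1, 64, 28, 28): conv at stride 1, aa pools 2x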