File size: 1,622 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
""" Create Conv2d Factory Method
Hacked together by / Copyright 2020 Ross Wightman
"""
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d
from .conv2d_same import create_conv2d_pad
def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
    """ Select a 2d convolution implementation based on arguments

    Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
    Used extensively by EfficientNet, MobileNetv3 and related networks.

    Dispatch rules, as implemented below:
      * `kernel_size` is a `list` -> `MixedConv2d`. Only lists select this path;
        ints, tuples and other iterables go to the regular conv path and specify h, w.
      * `num_experts` is present in kwargs and > 0 -> `CondConv2d`.
      * anything else -> `create_conv2d_pad`.

    Args:
        in_channels: number of input channels, forwarded to the selected layer.
        out_channels: number of output channels, forwarded to the selected layer.
        kernel_size: a list selects the mixed-kernel path, any other value is
            forwarded unchanged to the selected layer.
        **kwargs: extra layer arguments. The following keys are interpreted here:
            depthwise: if truthy on the non-list path, `groups` is forced to `in_channels`.
            groups: on the list path it must be 1 or `in_channels` (the latter is
                translated into `depthwise=True`); on the non-list path it is forwarded
                as `groups=` unless `depthwise` overrides it.
            num_experts: a value > 0 selects `CondConv2d`.
            All remaining keys are passed through untouched.

    Returns:
        The module built by `MixedConv2d`, `CondConv2d` or `create_conv2d_pad`.

    Raises:
        AssertionError: if `kernel_size` is a list and `num_experts` is given, or if
            `kernel_size` is a list and `groups` is neither 1 nor `in_channels`.
    """
    if isinstance(kernel_size, list):
        assert 'num_experts' not in kwargs, \
            'MixNet + CondConv combo not supported currently'
        if 'groups' in kwargs:
            groups = kwargs.pop('groups')
            if groups == in_channels:
                # MixedConv2d takes a depthwise flag instead of a groups count
                kwargs['depthwise'] = True
            else:
                assert groups == 1, \
                    'groups must be 1 or equal to in_channels when kernel_size is a list'
        # We're going to use only lists for defining the MixedConv2d kernel groups,
        # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
        return MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)

    depthwise = kwargs.pop('depthwise', False)
    # Always remove 'groups' from kwargs, even when depthwise overrides it. Otherwise a
    # caller passing both depthwise=True and groups=... would hit
    # "TypeError: got multiple values for keyword argument 'groups'" in the calls below.
    groups = kwargs.pop('groups', 1)
    if depthwise:
        # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
        groups = in_channels

    if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
        return CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
    return create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
|