|
""" PyTorch Mixed Convolution |
|
|
|
Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
|
|
import torch |
|
from torch import nn as nn |
|
|
|
from .conv2d_same import create_conv2d_pad |
|
|
|
|
|
def _split_channels(num_chan, num_groups):
    """Partition ``num_chan`` channels into ``num_groups`` integer group sizes.

    Each group starts with ``num_chan // num_groups`` channels. Whatever is
    left over after that floor division is added to the first group, so the
    returned sizes always sum to ``num_chan``.

    Args:
        num_chan: Total number of channels to distribute.
        num_groups: Number of groups to distribute them over.

    Returns:
        A list of ``num_groups`` sizes, the first one holding any remainder.

    Raises:
        IndexError: If ``num_groups`` is less than 1 (no first group exists).
    """
    group_sizes = [num_chan // num_groups for _ in range(num_groups)]
    # Floor division can under-allocate; hand the shortfall to group 0.
    leftover = num_chan - sum(group_sizes)
    group_sizes[0] += leftover
    return group_sizes
|
|
|
|
|
class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution

    Holds one convolution per entry of ``kernel_size``. The input is split
    along dimension 1 into one chunk per convolution, each chunk is passed
    through its own convolution, and the results are concatenated back
    together along dimension 1.

    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
        """Build one convolution branch per kernel size.

        Args:
            in_channels: Total input channels, divided across the branches.
            out_channels: Total output channels, divided across the branches.
            kernel_size: A list with one kernel size per branch. Any value
                that is not a ``list`` is treated as a single branch.
            stride: Forwarded to ``create_conv2d_pad`` for every branch.
            padding: Forwarded to ``create_conv2d_pad`` for every branch.
            dilation: Forwarded to ``create_conv2d_pad`` for every branch.
            depthwise: When true, each branch is created with ``groups`` equal
                to its own input channel count; otherwise ``groups`` is 1.
            **kwargs: Extra keyword arguments forwarded to ``create_conv2d_pad``.
        """
        super().__init__()

        # Anything that is not already a list counts as one kernel / one branch.
        kernels = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        branch_count = len(kernels)
        in_splits = _split_channels(in_channels, branch_count)
        out_splits = _split_channels(out_channels, branch_count)

        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)
        # Per-branch input widths, used by forward() to split the input.
        self.splits = in_splits

        for idx, kernel in enumerate(kernels):
            branch_in = in_splits[idx]
            branch_out = out_splits[idx]
            groups = branch_in if depthwise else 1
            branch = create_conv2d_pad(
                branch_in, branch_out, kernel, stride=stride,
                padding=padding, dilation=dilation, groups=groups, **kwargs)
            # Branches are registered under their index as a string key.
            self.add_module(str(idx), branch)

    def forward(self, x):
        """Run each chunk of ``x`` through its branch and concatenate the results.

        ``x`` is split along dimension 1 into chunks sized by ``self.splits``;
        the i-th chunk goes to the i-th registered module, and the outputs are
        concatenated along dimension 1.
        """
        chunks = torch.split(x, self.splits, 1)
        outputs = []
        for idx, branch in enumerate(self.values()):
            outputs.append(branch(chunks[idx]))
        return torch.cat(outputs, 1)
|
|