|
""" Median Pool |
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from .helpers import to_2tuple, to_4tuple |
|
|
|
|
|
class MedianPool2d(nn.Module):
    """ Median pool (usable as median filter when stride=1) module.

    Args:
        kernel_size: size of pooling kernel, int or 2-tuple
        stride: pool stride, int or 2-tuple
        padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
        same: override padding and enforce same padding, boolean
        padding_mode: mode passed to F.pad to fill the border ('reflect',
            'replicate', 'circular' or 'constant'). Defaults to 'reflect',
            the previously hard-coded behaviour. Note that 'reflect' requires
            every pad amount to be smaller than the matching input dimension;
            'replicate' has no such restriction.
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False, padding_mode='reflect'):
        super().__init__()
        self.k = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        self.padding = to_4tuple(padding)
        self.same = same
        self.padding_mode = padding_mode

    @staticmethod
    def _same_pad_total(size, kernel, stride):
        """Return the total padding along one axis for 'same'-style pooling.

        The amount is chosen so the pooled axis ends up with ceil(size / stride)
        outputs (TensorFlow 'SAME'-style arithmetic).
        """
        remainder = size % stride
        if remainder == 0:
            # The last window starts exactly one stride before the end.
            return max(kernel - stride, 0)
        return max(kernel - remainder, 0)

    def _padding(self, x):
        """Return the (left, right, top, bottom) padding to apply to ``x``.

        When ``same`` is False this is simply the configured ``padding``.
        Otherwise it is computed from the spatial size of ``x`` (dims 2 and 3).
        Any odd leftover pixel goes on the right / bottom side.
        """
        if not self.same:
            return self.padding
        ih, iw = x.size()[2:]
        ph = self._same_pad_total(ih, self.k[0], self.stride[0])
        pw = self._same_pad_total(iw, self.k[1], self.stride[1])
        pl = pw // 2
        pt = ph // 2
        return (pl, pw - pl, pt, ph - pt)

    def forward(self, x):
        """Median-pool dims 2 and 3 of ``x``.

        Expects a 4D tensor, presumably (N, C, H, W). The output keeps the first
        two dims and holds, per window, the median of its kh * kw values. For
        even-sized windows torch.median yields the lower of the two middle
        values.
        """
        padding = self._padding(x)
        if any(padding):
            # Skip F.pad when there is nothing to pad; it would only copy x.
            x = F.pad(x, padding, mode=self.padding_mode)
        # Sliding windows over H, then W -> (..., H_out, W_out, kh, kw).
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        # Flatten each window. Unfolded windows overlap in memory, so reshape
        # copies here, just as contiguous().view() would.
        x = x.reshape(x.shape[:4] + (-1,))
        return x.median(dim=-1)[0]
|
|