""" PyTorch selectable adaptive pooling
Adaptive pooling with the ability to select the type of pooling from:
* 'avg' - Average pooling
* 'max' - Max pooling
* 'avgmax' - Sum of average and max pooling re-scaled by 0.5
* 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim
Both a functional and an nn.Module version of the pooling are provided.
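
Example (illustrative global 'avgmax' pooling over an NCHW input):
    >>> import torch
    >>> pool = SelectAdaptivePool2d(pool_type='avgmax', flatten=True)
    >>> pool(torch.randn(2, 512, 7, 7)).shape
    torch.Size([2, 512])
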
Hacked together by / Copyright 2020 Ross Wightman
"""

from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from .format import get_spatial_dim, get_channel_dim

_int_tuple_2_t = Union[int, Tuple[int, int]]


def adaptive_pool_feat_mult(pool_type='avg'):
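    """Return the feature dim multiplier for a pooling type (2 for concat pooling, else 1)."""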
if pool_type.endswith('catavgmax'):
return 2
else:
return 1


def adaptive_avgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
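    """Pool with both adaptive avg and max pooling, returning the mean of the two results."""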
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return 0.5 * (x_avg + x_max)


def adaptive_catavgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
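    """Pool with both adaptive avg and max pooling, concatenating the results along the channel dim."""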
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return torch.cat((x_avg, x_max), 1)


def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):
"""Selectable global pooling function with dynamic input kernel size
"""
if pool_type == 'avg':
x = F.adaptive_avg_pool2d(x, output_size)
elif pool_type == 'avgmax':
x = adaptive_avgmax_pool2d(x, output_size)
elif pool_type == 'catavgmax':
x = adaptive_catavgmax_pool2d(x, output_size)
elif pool_type == 'max':
x = F.adaptive_max_pool2d(x, output_size)
else:
assert False, 'Invalid pool type: %s' % pool_type
return x


class FastAdaptiveAvgPool(nn.Module):
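    """Global average pooling as a simple mean over the spatial dims, layout (NCHW/NHWC) aware."""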
    def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveAvgPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)

    def forward(self, x):
return x.mean(self.dim, keepdim=not self.flatten)


class FastAdaptiveMaxPool(nn.Module):
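    """Global max pooling as an amax over the spatial dims, layout (NCHW/NHWC) aware."""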
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveMaxPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)

    def forward(self, x):
return x.amax(self.dim, keepdim=not self.flatten)


class FastAdaptiveAvgMaxPool(nn.Module):
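    """Global pooling that averages the results of mean and amax over the spatial dims."""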
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveAvgMaxPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)

    def forward(self, x):
x_avg = x.mean(self.dim, keepdim=not self.flatten)
x_max = x.amax(self.dim, keepdim=not self.flatten)
return 0.5 * x_avg + 0.5 * x_max


class FastAdaptiveCatAvgMaxPool(nn.Module):
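    """Global pooling that concatenates spatial mean and amax along the channel (or flattened) dim."""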
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveCatAvgMaxPool, self).__init__()
self.flatten = flatten
self.dim_reduce = get_spatial_dim(input_fmt)
if flatten:
self.dim_cat = 1
else:
self.dim_cat = get_channel_dim(input_fmt)

    def forward(self, x):
x_avg = x.mean(self.dim_reduce, keepdim=not self.flatten)
x_max = x.amax(self.dim_reduce, keepdim=not self.flatten)
return torch.cat((x_avg, x_max), self.dim_cat)


class AdaptiveAvgMaxPool2d(nn.Module):
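    """nn.Module wrapper around adaptive_avgmax_pool2d."""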
def __init__(self, output_size: _int_tuple_2_t = 1):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.output_size = output_size

    def forward(self, x):
return adaptive_avgmax_pool2d(x, self.output_size)


class AdaptiveCatAvgMaxPool2d(nn.Module):
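    """nn.Module wrapper around adaptive_catavgmax_pool2d."""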
def __init__(self, output_size: _int_tuple_2_t = 1):
super(AdaptiveCatAvgMaxPool2d, self).__init__()
self.output_size = output_size

    def forward(self, x):
return adaptive_catavgmax_pool2d(x, self.output_size)


class SelectAdaptivePool2d(nn.Module):
"""Selectable global pooling layer with dynamic input kernel size
"""
def __init__(
self,
output_size: _int_tuple_2_t = 1,
pool_type: str = 'fast',
flatten: bool = False,
input_fmt: str = 'NCHW',
):
super(SelectAdaptivePool2d, self).__init__()
assert input_fmt in ('NCHW', 'NHWC')
        self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TorchScript typing
if not pool_type:
self.pool = nn.Identity() # pass through
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
elif pool_type.startswith('fast') or input_fmt != 'NCHW':
assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.'
if pool_type.endswith('catavgmax'):
self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('avgmax'):
self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('max'):
self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
else:
self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
self.flatten = nn.Identity()
else:
assert input_fmt == 'NCHW'
if pool_type == 'avgmax':
self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
self.pool = nn.AdaptiveAvgPool2d(output_size)
self.flatten = nn.Flatten(1) if flatten else nn.Identity()

    def is_identity(self):
        return not self.pool_type

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        return x

    def feat_mult(self):
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'pool_type=' + self.pool_type \
            + ', flatten=' + str(self.flatten) + ')'