File size: 4,882 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
""" Activations
A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
def swish(x, inplace: bool = False):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
"""
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
class Swish(nn.Module):
def __init__(self, inplace: bool = False):
super(Swish, self).__init__()
self.inplace = inplace
def forward(self, x):
return swish(x, self.inplace)
def mish(x, inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
NOTE: I don't have a working inplace variant
"""
return x.mul(F.softplus(x).tanh())
class Mish(nn.Module):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
def __init__(self, inplace: bool = False):
super(Mish, self).__init__()
def forward(self, x):
return mish(x)
def sigmoid(x, inplace: bool = False):
return x.sigmoid_() if inplace else x.sigmoid()
# PyTorch has this, but not with a consistent inplace argmument interface
class Sigmoid(nn.Module):
def __init__(self, inplace: bool = False):
super(Sigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
return x.sigmoid_() if self.inplace else x.sigmoid()
def tanh(x, inplace: bool = False):
return x.tanh_() if inplace else x.tanh()
# PyTorch has this, but not with a consistent inplace argmument interface
class Tanh(nn.Module):
def __init__(self, inplace: bool = False):
super(Tanh, self).__init__()
self.inplace = inplace
def forward(self, x):
return x.tanh_() if self.inplace else x.tanh()
def hard_swish(x, inplace: bool = False):
inner = F.relu6(x + 3.).div_(6.)
return x.mul_(inner) if inplace else x.mul(inner)
class HardSwish(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwish, self).__init__()
self.inplace = inplace
def forward(self, x):
return hard_swish(x, self.inplace)
def hard_sigmoid(x, inplace: bool = False):
if inplace:
return x.add_(3.).clamp_(0., 6.).div_(6.)
else:
return F.relu6(x + 3.) / 6.
class HardSigmoid(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
return hard_sigmoid(x, self.inplace)
def hard_mish(x, inplace: bool = False):
""" Hard Mish
Experimental, based on notes by Mish author Diganta Misra at
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
"""
if inplace:
return x.mul_(0.5 * (x + 2).clamp(min=0, max=2))
else:
return 0.5 * x * (x + 2).clamp(min=0, max=2)
class HardMish(nn.Module):
def __init__(self, inplace: bool = False):
super(HardMish, self).__init__()
self.inplace = inplace
def forward(self, x):
return hard_mish(x, self.inplace)
class PReLU(nn.PReLU):
"""Applies PReLU (w/ dummy inplace arg)
"""
def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.prelu(input, self.weight)
def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
return F.gelu(x)
class GELU(nn.Module):
"""Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
"""
def __init__(self, inplace: bool = False):
super(GELU, self).__init__()
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.gelu(input)
def gelu_tanh(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
return F.gelu(x, approximate='tanh')
class GELUTanh(nn.Module):
"""Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
"""
def __init__(self, inplace: bool = False):
super(GELUTanh, self).__init__()
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.gelu(input, approximate='tanh')
def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
return x * torch.sigmoid(1.702 * x)
class QuickGELU(nn.Module):
"""Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
"""
def __init__(self, inplace: bool = False):
super(QuickGELU, self).__init__()
def forward(self, input: torch.Tensor) -> torch.Tensor:
return quick_gelu(input)
|