# Plachta's picture
# Upload 69 files
# a4d0945 verified
# raw
# history blame
# 30.4 kB
import copy
import numbers
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from .activation import MultiheadAttention
from .scaling import ActivationBalancer, BalancedDoubleSwish
from .scaling import BasicNorm as _BasicNorm
from .rotary_embedding import RotaryEmbedding
from .conv import ConvolutionModule, MultiLayeredConv1d
_shape_t = Union[int, List[int], torch.Size]
class LayerNorm(nn.Module):
    """Layer normalization that can carry an extra embedding alongside its input.

    The module wraps ``F.layer_norm``. When ``forward`` receives a tuple
    ``(input, embedding)`` it normalizes ``input`` and returns
    ``(normalized, embedding)`` with the embedding passed through untouched;
    when it receives a plain tensor it returns just the normalized tensor.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Create the module.

        Args:
            normalized_shape: size (or sizes) of the trailing dimensions to
                normalize over; a single integer is promoted to a 1-tuple.
            eps: value passed to ``F.layer_norm`` as its epsilon.
            elementwise_affine: if true, learnable ``weight`` and ``bias`` of
                shape ``normalized_shape`` are created; otherwise both are
                registered as ``None``.
            device: device for the affine parameters.
            dtype: dtype for the affine parameters.
        """
        super().__init__()
        # A bare integer means "normalize over one trailing dimension".
        shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )
        self.normalized_shape = tuple(shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if not self.elementwise_affine:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, device=device, dtype=dtype)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, device=device, dtype=dtype)
            )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set ``weight`` to ones and ``bias`` to zeros when they exist."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize ``input``.

        Args:
            input: a tensor, or a tuple ``(tensor, embedding)``.
            embedding: must be ``None`` when ``input`` is a plain tensor; it
                is overwritten by the tuple's second item otherwise.

        Returns:
            The normalized tensor, or ``(normalized, embedding)`` when the
            input was a tuple.
        """
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input
        else:
            assert embedding is None

        normalized = F.layer_norm(
            input,
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return (normalized, embedding) if paired else normalized

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module repr."""
        template = (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}"
        )
        return template.format(**self.__dict__)
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Wraps another normalization module and modulates its output with a scale
    and a shift that are both produced by a linear projection of an
    embedding: ``weight * norm(input) + bias``.
    """

    def __init__(self, d_model, norm) -> None:
        """Create the module.

        Args:
            d_model: feature size; the projection maps ``d_model`` to
                ``2 * d_model`` (scale and shift halves).
            norm: the wrapped normalization module; it must expose ``eps``.
        """
        super().__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        """Apply the wrapped norm and modulate it with the embedding.

        Args:
            input: a tensor, or a tuple ``(tensor, embedding)``.
            embedding: the conditioning embedding; replaced by the tuple's
                second item when ``input`` is a tuple.

        Returns:
            The modulated tensor, or ``(modulated, embedding)`` when the
            input was a tuple.
        """
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input

        # First half of the projection is the scale, second half the shift.
        scale, shift = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        modulated = scale * self.norm(input) + shift
        if paired:
            return (modulated, embedding)
        return modulated
class BasicNorm(_BasicNorm):
    """``_BasicNorm`` variant whose ``forward`` also accepts ``(input, embedding)``.

    The embedding is never used by the normalization; it is only passed
    through so the module can be chained with other tuple-aware layers.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        """Create the module.

        Args:
            d_model: forwarded to the parent constructor.
            eps: forwarded to the parent constructor as ``eps``.
            device: accepted for signature compatibility; not used here.
            dtype: accepted for signature compatibility; not used here.
        """
        super().__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize ``input`` with the parent implementation.

        Args:
            input: a tensor, or a tuple ``(tensor, embedding)``.
            embedding: must be ``None`` when ``input`` is a plain tensor.

        Returns:
            The normalized tensor, or ``(normalized, embedding)`` when the
            input was a tuple.
        """
        if not isinstance(input, tuple):
            assert embedding is None
            return super().forward(input)

        hidden, embedding = input
        return (super().forward(hidden), embedding)
class BalancedBasicNorm(nn.Module):
    """An ``ActivationBalancer`` followed by a ``BasicNorm``.

    The balancer is applied to the tensor first and its output is then
    normalized; an accompanying embedding, if any, is passed through.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        """Create the balancer and the norm.

        Args:
            d_model: feature size given to both sub-modules.
            eps: epsilon given to ``BasicNorm``.
            device: forwarded to ``BasicNorm``.
            dtype: forwarded to ``BasicNorm``.
        """
        super().__init__()
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Balance then normalize ``input``.

        Args:
            input: a tensor, or a tuple ``(tensor, embedding)``.
            embedding: must be ``None`` when ``input`` is a plain tensor.

        Returns:
            Whatever ``self.norm`` returns: a tensor for tensor input, or a
            ``(normalized, embedding)`` tuple for tuple input.
        """
        if not isinstance(input, tuple):
            assert embedding is None
            return self.norm(self.balancer(input))

        hidden, embedding = input
        balanced = self.balancer(hidden)
        return self.norm((balanced, embedding))
class IdentityNorm(nn.Module):
    """Drop-in replacement for a norm layer that returns its input unchanged."""

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        """Create the module; every argument is accepted and ignored."""
        super().__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Return ``input`` as-is.

        Args:
            input: a tensor, or a tuple ``(tensor, embedding)``; either is
                returned untouched.
            embedding: must be ``None`` when ``input`` is a plain tensor.
        """
        # Only the plain-tensor form constrains the separate embedding arg.
        if not isinstance(input, tuple):
            assert embedding is None
        return input
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps: epsilon value, default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super().__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Learnable per-feature gain, initialised to one.
        self.register_parameter("scale", nn.Parameter(torch.ones(d)))
        if self.bias:
            # Optional learnable per-feature shift, initialised to zero.
            self.register_parameter("offset", nn.Parameter(torch.zeros(d)))

    def forward(self, x):
        """Divide ``x`` by its RMS over the last dimension and rescale.

        When ``p`` lies outside ``[0, 1]`` the RMS is taken over all ``d``
        features; otherwise only over the first ``int(d * p)`` features.

        :param x: input tensor whose last dimension is the feature dimension
        :return: ``scale * x / (rms + eps)``, plus ``offset`` when bias is on
        """
        if self.p < 0. or self.p > 1.:
            measured = x
            count = self.d
        else:
            count = int(self.d * self.p)
            measured, _ = torch.split(x, [count, self.d - count], dim=-1)

        l2 = measured.norm(2, dim=-1, keepdim=True)
        rms = l2 * count ** (-1. / 2)
        scaled = self.scale * (x / (rms + self.eps))

        return scaled + self.offset if self.bias else scaled
class TransformerEncoderLayer(nn.Module):
    """Single encoder layer: self-attention, optional conv module, feed-forward.

    ``forward``/``infer`` accept either a tensor or a tuple
    ``(tensor, stage_embedding)``; the stage embedding is handed to every
    normalization layer so adaptive norms can condition on it.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
        use_conv_module: bool = False,
        use_depth_wise_conv: bool = False,
        conv_ignore_prefix_len: int = 0,
        cross_attention: bool = False,
    ) -> None:
        """Build the layer.

        Args:
            d_model: feature size of the layer.
            nhead: number of attention heads.
            dim_feedforward: hidden size of the feed-forward block.
            dropout: dropout probability used throughout the layer.
            activation: feed-forward activation; a string (``"relu"`` or
                ``"gelu"``), a ``functools.partial`` that is called with
                ``d_model``, the ``BalancedDoubleSwish`` class itself, or a
                ready callable.
            batch_first: forwarded to the self-attention module.
            norm_first: if true, normalize before each sub-block (pre-norm),
                otherwise after the residual sum (post-norm).
            device, dtype: factory arguments for sub-modules.
            linear1_self_attention_cls, linear2_self_attention_cls: linear
                classes forwarded to the self-attention module.
            linear1_feedforward_cls, linear2_feedforward_cls: linear classes
                used for the feed-forward block.
            layer_norm_cls: class used to build the normalization layers.
            layer_norm_eps: epsilon given to the normalization layers.
            adaptive_layer_norm: wrap the norms in ``AdaptiveLayerNorm``.
            use_conv_module: add a ``ConvolutionModule`` between attention
                and feed-forward.
            use_depth_wise_conv: use ``MultiLayeredConv1d`` instead of the
                two-linear feed-forward block.
            conv_ignore_prefix_len: forwarded to ``ConvolutionModule``.
            cross_attention: also create ``cross_attn`` and ``norm3``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Always define the flag so reading it never raises AttributeError
        # on layers built without cross-attention.
        self.has_cross_attention = bool(cross_attention)
        if cross_attention:
            self.cross_attn = nn.MultiheadAttention(
                d_model, nhead, 0.1, batch_first=True
            )
            self.norm3 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

        # Implementation of Feedforward model
        self.use_depth_wise_conv = use_depth_wise_conv
        self.use_conv_module = use_conv_module
        if not use_depth_wise_conv:
            self.linear1 = linear1_feedforward_cls(
                d_model, dim_feedforward, **factory_kwargs
            )
            self.dropout = nn.Dropout(dropout)
            self.linear2 = linear2_feedforward_cls(
                dim_feedforward, d_model, **factory_kwargs
            )
        else:
            self.dw_ffn = MultiLayeredConv1d(
                in_chans=d_model,
                hidden_chans=dim_feedforward,
                kernel_size=5,
                dropout_rate=dropout,
            )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            activation = BalancedDoubleSwish(d_model)
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            # With identity norms the second norm is a balanced basic norm.
            norm2 = BalancedBasicNorm(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
        else:
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

        self.rotary_emb = RotaryEmbedding(dim=d_model // nhead)

        if use_conv_module:
            self.conv_module = ConvolutionModule(
                d_model,
                kernel_size=31,
                activation=activation,
                ignore_prefix_len=conv_ignore_prefix_len,
            )
            norm_conv = LayerNorm(d_model)  # for the CNN module
            if adaptive_layer_norm:
                norm_conv = AdaptiveLayerNorm(d_model, norm_conv)
            self.norm_conv = norm_conv
        else:
            self.conv_module = None

    def __setstate__(self, state):
        """Restore state, defaulting ``activation`` to ReLU when missing."""
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    @staticmethod
    def _check_key_padding_mask(key_padding_mask: Optional[Tensor]) -> None:
        """Raise AssertionError unless the mask is None, bool, or floating."""
        if key_padding_mask is None:
            return
        if key_padding_mask.dtype != torch.bool and not torch.is_floating_point(
            key_padding_mask
        ):
            raise AssertionError(
                "only bool and floating types of key_padding_mask are supported"
            )

    def forward(
        self,
        src: Tensor,
        context: Optional[Tensor] = None,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        use_rope: bool = False,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer, or a tuple
                ``(sequence, stage_embedding)`` (required).
            context: accepted but currently unused; the cross-attention
                call that would consume it is not applied in this method.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            use_rope: forwarded to the self-attention module together with
                this layer's rotary embedding.

        Returns:
            The transformed sequence, or ``(sequence, stage_embedding)`` when
            ``src`` was a tuple.

        Raises:
            AssertionError: if ``src_key_padding_mask`` is neither bool nor
                floating point.

        Shape:
            see the docs in Transformer class.
        """
        is_src_tuple = isinstance(src, tuple)
        if is_src_tuple:
            x, stage_embedding = src
        else:
            x, stage_embedding = src, None

        self._check_key_padding_mask(src_key_padding_mask)

        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
                use_rope=use_rope,
            )
            if self.conv_module is not None:
                residual = x
                x = self.norm_conv(x, stage_embedding)
                x = residual + self.dropout1(self.conv_module(x))
            # NOTE: ``cross_attn``/``norm3`` may exist (see ``__init__``) but
            # are intentionally not applied here.
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            x = self.norm1(
                x
                + self._sa_block(
                    x, src_mask, src_key_padding_mask, use_rope=use_rope
                ),
                stage_embedding,
            )
            if self.conv_module is not None:
                # Use ``dropout1`` like the pre-norm branch: ``self.dropout``
                # only exists when ``use_depth_wise_conv`` is False.
                x = x + self.dropout1(self.conv_module(x))
                x = self.norm_conv(x, stage_embedding)
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    def infer(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        past_kv: Optional[Tensor] = None,
        use_cache: bool = False,
        use_rope: bool = False,
    ):
        """Run the layer through the attention module's ``infer`` path.

        Only the pre-norm (``norm_first=True``) configuration is supported.
        The conv module, if any, is not applied on this path.

        Args:
            src: the sequence, or a tuple ``(sequence, stage_embedding)``.
            src_mask: attention mask (optional).
            src_key_padding_mask: key padding mask (optional).
            past_kv: forwarded to ``self_attn.infer``.
            use_cache: forwarded to ``self_attn.infer``.
            use_rope: forwarded to ``self_attn.infer``.

        Returns:
            ``(x, kv)`` for tensor input, where ``kv`` is the second value
            returned by ``self_attn.infer``; ``(x, stage_embedding)`` for
            tuple input.

        Raises:
            AssertionError: if ``src_key_padding_mask`` is neither bool nor
                floating point.
            NotImplementedError: if the layer was built with
                ``norm_first=False``.
        """
        is_src_tuple = isinstance(src, tuple)
        if is_src_tuple:
            x, stage_embedding = src
        else:
            x, stage_embedding = src, None

        self._check_key_padding_mask(src_key_padding_mask)

        if not self.norm_first:
            # Previously this path left ``kv`` unbound (tensor input) or
            # silently returned the input untouched (tuple input).
            raise NotImplementedError(
                "TransformerEncoderLayer.infer only supports norm_first=True"
            )

        x_attn_out, kv = self.self_attn.infer(
            self.norm1(x, stage_embedding),
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=False,
            past_kv=past_kv,
            use_cache=use_cache,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )
        x = x + x_attn_out
        x = x + self._ff_block(self.norm2(x, stage_embedding))

        if is_src_tuple:
            # NOTE(review): ``kv`` is dropped for tuple input, so a caller
            # that unpacks ``output, kv`` receives the stage embedding in
            # place of the cache. Kept as-is to preserve the return shape.
            return (x, stage_embedding)
        return (x, kv)

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        use_rope: bool = False,
    ) -> Tensor:
        """Self-attend over ``x`` and apply ``dropout1`` to the result."""
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """Apply the feed-forward block followed by ``dropout2``."""
        if self.use_depth_wise_conv:
            x = self.dw_ffn(x)
        else:
            x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        # Each layer is an independent deep copy of ``encoder_layer``.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        use_rope: bool = False,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder, or a tuple whose first item is
                the sequence (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: return layers' state (optional).
            use_rope: forwarded to every layer.

        Returns:
            The final output (after ``norm`` when one is set). When
            ``return_layer_states`` is true, ``(layer_states, output)`` where
            ``layer_states`` holds each layer's output taken before the
            final ``norm``.

        Shape:
            see the docs in Transformer class.
        """
        layer_states = []  # layers' output
        output = src
        for mod in self.layers:
            output = mod(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                use_rope=use_rope,
            )
            if return_layer_states:
                # Layers return a tuple only for tuple input; indexing a
                # plain tensor with [0] would slice its first dimension.
                layer_states.append(
                    output[0] if isinstance(output, tuple) else output
                )

        if self.norm is not None:
            output = self.norm(output)

        if return_layer_states:
            return layer_states, output
        return output

    def infer(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        past_kv: Optional[Tensor] = None,
        use_cache: bool = False,
        use_rope: bool = False,
    ):
        """Run every layer's ``infer`` in turn, threading per-layer caches.

        Args:
            src: the sequence to the encoder.
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: accepted but unused by this method.
            past_kv: one cache entry per layer, or ``None`` for no cache.
            use_cache: collect and return each layer's new cache entry.
            use_rope: forwarded to every layer.

        Returns:
            ``(output, new_kv)`` where ``new_kv`` is a tuple with one entry
            per layer when ``use_cache`` is true, else ``None``.
        """
        if past_kv is None:
            past_kv = tuple([None] * self.num_layers)

        collected_kv = []
        output = src
        for mod, past_layer_kv in zip(self.layers, past_kv):
            output, kv = mod.infer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                past_kv=past_layer_kv,
                use_cache=use_cache,
                use_rope=use_rope,
            )
            if use_cache:
                collected_kv.append(kv)

        if self.norm is not None:
            output = self.norm(output)

        new_kv = tuple(collected_kv) if use_cache else None
        return output, new_kv
class TransformerDecoderLayer(nn.Module):
    """Single decoder layer: self-attention, attention over memory, feed-forward.

    ``forward`` accepts either a tensor or a tuple ``(tensor,
    stage_embedding)`` as ``tgt``; the stage embedding is handed to every
    normalization layer so adaptive norms can condition on it.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        """Build the layer.

        Args:
            d_model: feature size of the layer.
            nhead: number of attention heads.
            dim_feedforward: hidden size of the feed-forward block.
            dropout: dropout probability used throughout the layer.
            activation: feed-forward activation; a string (``"relu"`` or
                ``"gelu"``), a ``functools.partial`` that is called with
                ``d_model``, the ``BalancedDoubleSwish`` class itself, or a
                ready callable.
            linear1_self_attention_cls, linear2_self_attention_cls: linear
                classes forwarded to both attention modules.
            linear1_feedforward_cls, linear2_feedforward_cls: linear classes
                used for the feed-forward block.
            batch_first: forwarded to both attention modules.
            norm_first: if true, normalize before each sub-block (pre-norm),
                otherwise after the residual sum (post-norm).
            device, dtype: factory arguments for sub-modules.
            layer_norm_cls: class used to build the normalization layers.
            layer_norm_eps: epsilon given to the normalization layers.
            adaptive_layer_norm: wrap the norms in ``AdaptiveLayerNorm``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            self.activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            self.activation = BalancedDoubleSwish(d_model)
        else:
            self.activation = activation

        if adaptive_layer_norm:
            norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm3 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
            self.norm3 = AdaptiveLayerNorm(d_model, norm3)
        else:
            self.norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            self.norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            if layer_norm_cls == IdentityNorm:
                # With identity norms the last norm is a balanced basic norm.
                self.norm3 = BalancedBasicNorm(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )
            else:
                self.norm3 = layer_norm_cls(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )

        self.rotary_emb = RotaryEmbedding(dim=d_model // nhead)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        use_rope: bool = False,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer, or a tuple
                ``(sequence, stage_embedding)`` (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            use_rope: forwarded to both attention blocks.

        Returns:
            ``(x, attn_map)`` for tensor input, where ``attn_map`` is the
            second value returned by ``_mha_block``; ``(x, stage_embedding)``
            when ``tgt`` was a tuple.

        Shape:
            see the docs in Transformer class.
        """
        tgt_is_tuple = isinstance(tgt, tuple)
        if tgt_is_tuple:
            x, stage_embedding = tgt
        else:
            x, stage_embedding = tgt, None

        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                tgt_mask,
                tgt_key_padding_mask,
                use_rope=use_rope,
            )
            x_mha_out, attn_map = self._mha_block(
                self.norm2(x, stage_embedding),
                memory,
                memory_mask,
                memory_key_padding_mask,
                use_rope=use_rope,
            )
            x = x + x_mha_out
            x = x + self._ff_block(self.norm3(x, stage_embedding))
        else:
            x = self.norm1(
                x
                + self._sa_block(
                    x, tgt_mask, tgt_key_padding_mask, use_rope=use_rope
                ),
                stage_embedding,
            )
            # ``_mha_block`` returns ``(output, attn_map)``; unpack it
            # instead of adding the tuple to ``x``, and keep ``attn_map``
            # so the final return can use it.
            x_mha_out, attn_map = self._mha_block(
                x,
                memory,
                memory_mask,
                memory_key_padding_mask,
                use_rope=use_rope,
            )
            x = self.norm2(x + x_mha_out, stage_embedding)
            x = self.norm3(x + self._ff_block(x), stage_embedding)

        if tgt_is_tuple:
            # NOTE(review): ``attn_map`` is dropped for tuple input, so a
            # caller that unpacks ``output, attn_map`` receives the stage
            # embedding instead. Kept as-is to preserve the return shape.
            return (x, stage_embedding)
        return x, attn_map

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        use_rope: bool = False,
    ) -> Tensor:
        """Self-attend over ``x`` and apply ``dropout1`` to the result."""
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        use_rope: bool = False,
    ) -> Tensor:
        """Attend from ``x`` over ``mem``.

        Returns:
            ``(dropout2(output), attn_map)``.
        """
        x = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )[0]
        # NOTE(review): the unpacking and ``x[0]`` below depend on the return
        # structure of the project ``MultiheadAttention``, which is not
        # visible from this file — confirm against that module.
        x, attn_map = x
        return self.dropout2(x[0]), attn_map

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """Apply the two-linear feed-forward block followed by ``dropout3``."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
class TransformerDecoder(nn.Module):
    r"""A stack of N decoder layers applied one after another.

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
        >>> tgt = torch.rand(10, 32, 512)
        >>> memory = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """

    __constants__ = ["norm"]

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        # Each layer is an independent deep copy of ``decoder_layer``.
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        return_attn: bool = False,
        use_rope: bool = False,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layers in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            return_attn: return cross attention maps of each layer (optional).
            use_rope: forwarded to every layer.

        Returns:
            ``(output, attn_maps)``. ``attn_maps`` holds one entry per layer
            when ``return_attn`` is true and is an empty list otherwise.

        Shape:
            see the docs in Transformer class.
        """
        per_layer_maps = []
        hidden = tgt
        for layer in self.layers:
            layer_result = layer(
                hidden,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                use_rope=use_rope,
            )
            hidden, layer_map = layer_result
            if return_attn:
                per_layer_maps.append(layer_map)

        if self.norm is None:
            return hidden, per_layer_maps
        return self.norm(hidden), per_layer_maps
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Map an activation name to its ``torch.nn.functional`` callable.

    Args:
        activation: ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu`` respectively.

    Raises:
        RuntimeError: for any other name.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu))
    for name, fn in supported:
        if activation == name:
            return fn

    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )