"""This code is taken from <https://github.com/alexandre01/deepsvg> | |
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte | |
from the paper >https://arxiv.org/pdf/2007.11301.pdf> | |
""" | |
import torch
from torch.nn import Linear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

from .functional import multi_head_attention_forward


class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information | |
from different representation subspaces. | |
See reference: Attention Is All You Need | |
.. math:: | |
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O | |
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) | |
Args: | |
embed_dim: total dimension of the model. | |
num_heads: parallel attention heads. | |
dropout: a Dropout layer on attn_output_weights. Default: 0.0. | |
bias: add bias as module parameter. Default: True. | |
add_bias_kv: add bias to the key and value sequences at dim=0. | |
add_zero_attn: add a new batch of zeros to the key and | |
value sequences at dim=1. | |
kdim: total number of features in key. Default: None. | |
vdim: total number of features in key. Default: None. | |
Note: if kdim and vdim are None, they will be set to embed_dim such that | |
query, key, and value have the same number of features. | |
Examples:: | |
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) | |
>>> attn_output, attn_output_weights = multihead_attn(query, key, value) | |
""" | |
    __annotations__ = {
        'bias_k': torch._jit_internal.Optional[torch.Tensor],
        'bias_v': torch._jit_internal.Optional[torch.Tensor],
    }
    __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
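        # When key/value dimensions match embed_dim, a single packed projection
        # weight of shape (3 * embed_dim, embed_dim) handles Q, K and V; otherwise
        # separate per-tensor projection weights are used and the unused
        # attribute is registered as None.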
        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)
        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
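        # Optional learnable biases appended to the key and value sequences
        # (see add_bias_kv in the class docstring).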
        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn
        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True
        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
r""" | |
Args: | |
query, key, value: map a query and a set of key-value pairs to an output. | |
See "Attention Is All You Need" for more details. | |
key_padding_mask: if provided, specified padding elements in the key will | |
be ignored by the attention. This is an binary mask. When the value is True, | |
the corresponding value on the attention layer will be filled with -inf. | |
need_weights: output attn_output_weights. | |
attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask | |
(i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all | |
the batches while a 3D mask allows to specify a different mask for the entries of each batch. | |
Shape: | |
- Inputs: | |
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is | |
the embedding dimension. | |
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is | |
the embedding dimension. | |
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is | |
the embedding dimension. | |
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. | |
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. | |
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, | |
S is the source sequence length. | |
- Outputs: | |
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, | |
E is the embedding dimension. | |
- attn_output_weights: :math:`(N, L, S)` where N is the batch size, | |
L is the target sequence length, S is the source sequence length. | |
""" | |
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
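

if __name__ == "__main__":
    # Minimal usage sketch (added; not part of the original deepsvg module).
    # It exercises the shapes documented in forward() above and assumes the file
    # is run in its package context (e.g. via `python -m ...`) so the relative
    # import of multi_head_attention_forward resolves.
    L, S, N, E = 7, 9, 2, 16  # target length, source length, batch size, embed dim
    attn = MultiheadAttention(embed_dim=E, num_heads=4, dropout=0.1)

    query = torch.randn(L, N, E)
    key = torch.randn(S, N, E)
    value = torch.randn(S, N, E)

    # Bool padding mask: True marks source positions to ignore (here, the last
    # two positions of every batch element).
    key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
    key_padding_mask[:, -2:] = True

    attn_output, attn_output_weights = attn(query, key, value,
                                            key_padding_mask=key_padding_mask)
    print(attn_output.shape)           # expected per docstring: (L, N, E) = (7, 2, 16)
    print(attn_output_weights.shape)   # expected per docstring: (N, L, S) = (2, 7, 9)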