Spaces:
Paused
Paused
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py | |
from typing import Optional | |
from typing import Tuple | |
import torch | |
from torch import Tensor | |
from torch.nn import Linear | |
from torch.nn import Module | |
from torch.nn.init import constant_ | |
from torch.nn.init import xavier_normal_ | |
from torch.nn.init import xavier_uniform_ | |
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear | |
from torch.nn.parameter import Parameter | |
from torch.nn import functional as F | |
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched | |
class MultiheadAttention(Module):
    """Multi-head attention that delegates its forward pass to
    ``multi_head_attention_forward_patched``.

    Parameter names (``in_proj_weight``, ``in_proj_bias``, ``q_proj_weight``,
    ``k_proj_weight``, ``v_proj_weight``, ``bias_k``, ``bias_v``,
    ``out_proj``) mirror those of ``torch.nn.MultiheadAttention``. Adapted
    from vall-e's ``valle/modules/activation.py``.

    NOTE(review): the ``cache`` argument and the ``*_onnx`` helper module
    suggest this variant targets cached / ONNX-exportable self-attention;
    confirm against ``multi_head_attention_forward_patched``.
    """

    __constants__ = ["batch_first"]

    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        """Create the projection parameters.

        Args:
            embed_dim: Model dimension. Must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            dropout: Dropout probability forwarded to the attention function
                together with ``training=self.training``.
            bias: Whether the input and output projections carry a bias.
            add_bias_kv: If True, allocate learnable ``bias_k`` / ``bias_v``
                of shape ``(1, 1, embed_dim)``.
            add_zero_attn: Stored and forwarded to the attention function.
            kdim, vdim: Key/value feature sizes. Default to ``embed_dim``.
                Differing sizes create separate q/k/v weight matrices
                (built-in ``Linear`` path only).
            batch_first: Stored for API compatibility. ``forward`` does not
                consult it.
            linear1_cls: ``Linear`` (default) creates raw ``Parameter``s
                initialised by ``_reset_parameters``. Any other class is
                instantiated as ``linear1_cls(embed_dim, 3 * embed_dim)``,
                and its ``weight``/``bias`` are exposed as ``in_proj_weight``
                / ``in_proj_bias``.
            linear2_cls: Output projection class. Only used when
                ``linear1_cls`` is not ``Linear``. Otherwise ``out_proj`` is
                always a ``NonDynamicallyQuantizableLinear``.
            device, dtype: Forwarded to every tensor/module created here.

        Raises:
            AssertionError: ``embed_dim`` is not divisible by ``num_heads``.
            NotImplementedError: a custom ``linear1_cls`` is combined with
                ``kdim``/``vdim`` different from ``embed_dim``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                # Separate projections because k/v widths differ from embed_dim.
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                # One packed (3 * embed_dim, embed_dim) matrix for q, k and v.
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            # Initialises the projections and, when present, bias_k / bias_v.
            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError(
                    "a custom linear1_cls requires kdim == vdim == embed_dim"
                )

            self.in_proj_linear = linear1_cls(
                embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
            )
            # Alias the wrapped module's parameters under the torch MHA names
            # so the attention function can consume them directly.
            self.in_proj_weight = self.in_proj_linear.weight

            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = self.in_proj_linear.bias
            else:
                self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            # _reset_parameters() is not called on this path (the linear
            # classes initialise themselves), so the optional key/value bias
            # vectors are initialised explicitly here, and only here.
            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
            if self.bias_v is not None:
                xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        """Initialise the built-in projection parameters.

        Applies Xavier-uniform init to the input projection weights, zeros to
        the input/output projection biases (when ``in_proj_bias`` exists), and
        Xavier-normal init to ``bias_k`` / ``bias_v`` when present.
        """
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        # in_proj_bias and out_proj.bias both exist exactly when bias=True.
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super().__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cache=None,
    ) -> Tensor:
        """Run self-attention over ``query``.

        ``key`` and ``value`` are accepted only for signature compatibility
        with ``torch.nn.MultiheadAttention``. They are ignored, and ``query``
        serves as query, key and value.

        Dims 0 and 1 of ``query`` are swapped before calling
        ``multi_head_attention_forward_patched``, and swapped back on its
        result. This happens regardless of ``self.batch_first``.
        NOTE(review): presumably callers pass (batch, seq, embed) and the
        helper expects (seq, batch, embed); confirm.

        Args:
            query: Input tensor (also used as key and value).
            key, value: Ignored.
            key_padding_mask, need_weights, attn_mask, average_attn_weights:
                Forwarded unchanged to the attention function.
            cache: Forwarded unchanged. NOTE(review): presumably an
                incremental-decoding KV cache; confirm its structure against
                the helper.

        Returns:
            A single tensor: the helper's output with dims 0 and 1 swapped
            back. This method never returns attention weights.
        """
        query = key = value = query.transpose(1, 0)
        attn_output = multi_head_attention_forward_patched(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            cache=cache,
        )
        return attn_output.transpose(1, 0)