import math
from typing import Optional

import opt_einsum as oe
import torch as T
from torch import Tensor, nn
from torch.nn import functional as F

einsum = oe.contract

def masked_softmax(xs: Tensor, mask: Tensor, dim: int = -1) -> Tensor:
    """Softmax over `dim` that ignores positions where `mask` is False."""
    xs = xs.masked_fill(~mask, -1e9)
    return F.softmax(xs, dim=dim)

class Attention(nn.Module):
    def __init__(
        self,
        kind: str,
        query_dim: int,
        input_dim: int,
        output_dim: Optional[int] = None,
        activation: str = 'auto',
        scaled=True,  # True, False, or 'seqlen' (sequence-length scaling in 'dot' mode)
    ):
        super().__init__()
        assert kind in ['dot', 'linear']

        self.kind = kind
        self.Dq = query_dim
        self.Din = input_dim
        self.Dout = output_dim or self.Din
        self.activation = activation
        self.scaled = scaled

        self.Wq_ = nn.Linear(self.Dq, self.Din)
        self.Wk_ = nn.Linear(self.Din, self.Din)
        self.Wv_ = nn.Linear(self.Din, self.Dout)
        self.Wz_ = nn.Linear(self.Dout, self.Dout)  # attended values have Dout features (from Wv_)

    def forward(
        self,
        query: Tensor,
        data: Tensor,
        content_mask: Optional[Tensor] = None,
        prejudice_mask: Optional[Tensor] = None,
    ):
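        """Apply 'dot' or 'linear' attention of per-word queries over `data`.

        Expected shapes, inferred from the einsum expressions below (the same
        `seq` length is assumed for `query` and `data`):
            query:          (batch, seq, words, query_dim)
            data:           (batch, seq, input_dim)
            content_mask:   (batch, seq, words), boolean
            prejudice_mask: (batch, seq, seq), boolean
        Returns ``(zs, att_map)``; ``att_map`` is None in 'linear' mode.
        """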
        dimB, dimS, dimW, dimI = query.shape

        qs = self.Wq_(query)
        ks = self.Wk_(data)
        vs = self.Wv_(data)

        if content_mask is not None:
            words_mask = content_mask.any(2)
        else:
            words_mask = qs.new_ones((dimB, dimS))

        att_map = None  # only the 'dot' branch produces an explicit attention map
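        # 'linear' attention: a softmax-free formulation with a ReLU feature map;
        # keys are contracted with values first, so the full (seq x seq)
        # attention map is never materialized.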
        if self.kind == 'linear':
            assert prejudice_mask is None, "Linear mode does not support prejudice_mask."
            assert content_mask is not None, "Linear mode requires a content_mask."

            qs = T.relu(qs) * content_mask.unsqueeze(3)
            ks = T.relu(ks) * words_mask.unsqueeze(2)

            vks = einsum("bsi, bsz -> bzi", ks, vs)
            zs = einsum("bswi, bzi -> bswz", qs, vks)

            if self.scaled:
                ks = ks.sum(1)
                # normalizer is (batch, seq, words); broadcast it over the feature axis
                denom = einsum("bswi, bi -> bsw", qs, ks).unsqueeze(3) + 1e-9
                zs = zs / denom

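        # 'dot' attention: scaled dot-product attention of each query word over
        # the key positions (einsum indices: b=batch, q/k=sequence positions,
        # w=query word, i=input feature, z=output feature).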
        elif self.kind == 'dot':
            att_map = einsum("bqwi, bki -> bqkw", qs, ks)

            if self.scaled == 'seqlen':
                att_map_ndim = len(att_map.shape) - 1
                norm_coeff = words_mask.sum(1).view(-1, *([1] * att_map_ndim))
                att_map = att_map / T.sqrt(norm_coeff.float())
            else:
                att_map = att_map / math.sqrt(self.Din)

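            # Build a (batch, query, key, word) mask from whichever masks are
            # provided, then softmax over the key axis; with no masks, use a
            # plain softmax.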
            if content_mask is None and prejudice_mask is None:
                att_map = F.softmax(att_map, dim=2)
            else:
                if content_mask is None:
                    assert prejudice_mask is not None
                    qk_mask = prejudice_mask.unsqueeze(3)
                elif prejudice_mask is None:
                    qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
                else:
                    qk_mask = words_mask.unsqueeze(1).unsqueeze(3)
                    qk_mask = qk_mask * prejudice_mask.unsqueeze(3)

                att_map = masked_softmax(att_map, qk_mask.bool(), dim=2)

            zs = einsum("bqkw, bkz -> bqwz", att_map, vs)

        zs = self.Wz_(zs)
        return zs, att_map
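

# Minimal usage sketch, assuming the shape conventions inferred in
# `Attention.forward`; all dimension sizes here are illustrative only.
if __name__ == "__main__":
    attn = Attention(kind='dot', query_dim=32, input_dim=64, output_dim=48)
    query = T.randn(2, 5, 7, 32)                  # (batch, seq, words, query_dim)
    data = T.randn(2, 5, 64)                      # (batch, seq, input_dim)
    content_mask = T.ones(2, 5, 7, dtype=T.bool)  # keep every word
    zs, att_map = attn(query, data, content_mask=content_mask)
    print(zs.shape)       # torch.Size([2, 5, 7, 48])
    print(att_map.shape)  # torch.Size([2, 5, 5, 7])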