from typing import (
    Optional,
    Union,
)
import math
import torch as T
from torch import nn
from torch.nn import functional as F
import opt_einsum as oe
from torch import Tensor
einsum = oe.contract
def masked_softmax(xs: Tensor, mask: Tensor, dim: int = -1) -> Tensor:
    # Fill masked-out positions with a large negative value so they receive
    # (effectively) zero weight after the softmax.
    xs = xs.masked_fill(~mask, -1e9)
    return F.softmax(xs, dim=dim)
class Attention(nn.Module):
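    """
    Attention of word-level queries ([b, ts, tw, dq]) over step-level data
    ([b, ts, di]), producing outputs of shape [b, ts, tw, dz].

    Two kinds are supported:
      * 'dot':    scaled dot-product attention (https://arxiv.org/abs/1706.03762),
                  which materializes a [ts, ts] attention map per query word.
      * 'linear': softmax-free linear attention with a relu feature map,
                  which never builds the quadratic [ts, ts] intermediate
                  (and therefore returns att_map = None).
    """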
def __init__(
self,
kind: str,
query_dim: int,
input_dim: int,
        output_dim: Optional[int] = None,
        activation: str = 'auto',
        scaled: Union[bool, str] = True,
):
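        """
        kind:       'dot' for scaled dot-product attention, 'linear' for
                    softmax-free linear attention.
        query_dim:  size dq of the last query axis.
        input_dim:  size di of the last data axis.
        output_dim: size dz of the output; defaults to input_dim.
        activation: stored for configuration; not used by forward().
        scaled:     True scales by sqrt(input_dim) ('dot') or normalizes by
                    the key sum ('linear'); 'dot' also accepts 'seqlen' to
                    scale by sqrt of the unpadded key length; False disables
                    scaling.
        """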
super().__init__()
assert kind in [
'dot',
'linear',
]
self.kind = kind
self.Dq = query_dim
self.Din = input_dim
self.Dout = output_dim or self.Din
        self.activation = activation
self.scaled = scaled
self.Wq_ = nn.Linear(self.Dq, self.Din)
self.Wk_ = nn.Linear(self.Din, self.Din)
self.Wv_ = nn.Linear(self.Din, self.Dout)
        # The output projection acts on the attended values, which have size Dout.
        self.Wz_ = nn.Linear(self.Dout, self.Dout)
def forward(
self,
query: Tensor,
data: Tensor,
content_mask: Optional[Tensor] = None,
prejudice_mask: Optional[Tensor] = None,
):
#^ query: [b, ts, tw, dq]
#^ data: [b, ts, di]
#^ content_mask: [b, ts, tw]
#^ prejudice_mask: [b, ts, ts]
#^ => output: [b, ts, tw, dz]
dimB, dimS, dimW, dimI = query.shape
# TODO: Optimize out the [ts, ts, *] intermediate
qs = self.Wq_(query)
ks = self.Wk_(data)
vs = self.Wv_(data)
if content_mask is not None:
words_mask = content_mask.any(2)
#^ words_mask : [b, ts]
else:
words_mask = qs.new_ones((dimB, dimS))
if self.kind == 'linear':
# Ref: https://twitter.com/francoisfleuret/status/1267455240007188486
assert prejudice_mask is None, "Linear mode does not support prejudice_mask."
assert content_mask is not None, "Linear mode requires a content_mask."
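            # Linear attention: apply a relu feature map to queries and keys,
            # contract keys with values first (vks), then contract with the
            # queries, so the cost is linear in ts instead of quadratic.
            # The denominator below plays the role of the softmax normalizer.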
qs = T.relu(qs) * content_mask.unsqueeze(3)
#^ qs: [bswi]
ks = T.relu(ks) * words_mask.unsqueeze(2)
#^ ks: [bsi]
vks = einsum("bsi, bsz -> bzi", ks, vs)
#^ vks : [b, dz, di]
zs = einsum("bswi, bzi -> bswz", qs, vks)
#^ zs : [b, ts, tw, dz]
            if self.scaled:
                ks = ks.sum(1)
                #^ ks: [bi]
                denom = einsum("bswi, bi -> bsw", qs, ks) + 1e-9
                #^ denom : [b, ts, tw]
                zs = zs / denom.unsqueeze(3)
            # Linear mode never materializes an explicit attention map.
            att_map = None
elif self.kind == 'dot':
# Ref: https://arxiv.org/abs/1706.03762
# s=ts in q
# S=ts in ks,vs
att_map = einsum("bqwi, bki -> bqkw", qs, ks)
#^ [b, ts:q, ts:k, tw]
            if self.scaled == 'seqlen':
                # Scale by sqrt of the number of unpadded key steps.
                att_map_ndim = att_map.dim() - 1
                norm_coeff = words_mask.sum(1).view(-1, *([1] * att_map_ndim))
                #^ norm_coeff : [b, 1, 1, 1]
                att_map = att_map / T.sqrt(norm_coeff.float())
            elif self.scaled:
                # Standard scaling by sqrt(d_k).
                att_map = att_map / math.sqrt(self.Din)
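            # Build qk_mask over [b, ts:q, ts:k, tw]: content_mask removes
            # padded query word positions, words_mask removes fully padded
            # key steps, and prejudice_mask restricts which query steps may
            # attend to which key steps.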
if content_mask is None and prejudice_mask is None:
att_map = F.softmax(att_map, dim=2)
else:
if content_mask is None:
assert prejudice_mask is not None # !for mypy
qk_mask = prejudice_mask.unsqueeze(3)
#^ qk_mask : [b, ts:q, ts:k, tw^]
elif prejudice_mask is None:
qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
#^ qk_mask : [b, ts:q, ts:k^, tw]
else:
qk_mask = words_mask.unsqueeze(1).unsqueeze(3)
# qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
qk_mask = qk_mask * prejudice_mask.unsqueeze(3)
#^ qk_mask : [b, ts:q^, ts:k, tw]
att_map = masked_softmax(att_map, qk_mask.bool(), dim=2)
#^ att_map : [b, ts:q, ts:k, tw]
zs = einsum("bqkw, bkz -> bqwz", att_map, vs)
zs = self.Wz_(zs)
return zs, att_map
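
# ---------------------------------------------------------------------------
# Illustrative smoke test: arbitrary example sizes, intended only to document
# the input/output shapes the module expects for both attention kinds.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    b, ts, tw, dq, di, dz = 2, 5, 7, 16, 32, 24
    query = T.randn(b, ts, tw, dq)
    data = T.randn(b, ts, di)
    content_mask = T.ones(b, ts, tw, dtype=T.bool)
    prejudice_mask = T.ones(b, ts, ts, dtype=T.bool)

    dot = Attention(kind='dot', query_dim=dq, input_dim=di, output_dim=dz)
    zs, att_map = dot(query, data, content_mask, prejudice_mask)
    print(zs.shape, att_map.shape)   # [2, 5, 7, 24] and [2, 5, 5, 7]

    lin = Attention(kind='linear', query_dim=dq, input_dim=di, output_dim=dz)
    zs_lin, _ = lin(query, data, content_mask=content_mask)
    print(zs_lin.shape)              # [2, 5, 7, 24]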