|
from __future__ import print_function

import tensorflow as tf
import numpy as np


class MultiHeadAttention(tf.keras.layers.Layer):
    """Attention layer: multi-head scaled dot-product attention (for the encoder and the decoder).

    Args:
        n_heads: number of attention heads computed in parallel
        d_model: embedding size of the output features

    Call arguments:
        q: query, shape (batch_size, seq_len_q, depth_q)
        k: key, shape (batch_size, seq_len_k, depth_k)
        v: value, shape (batch_size, seq_len_v, depth_v)
        mask: boolean tensor of shape (batch_size, seq_len_q, seq_len_k), or None.
            True marks positions that must not be attended to.

    Since each key is paired with a value in scaled dot-product attention, we assume seq_len_k == seq_len_v.

    Returns:
        attention output of shape (batch_size, seq_len_q, d_model)

    A minimal usage sketch is given at the bottom of this module.
    """

    def __init__(self, n_heads, d_model, **kwargs):
        super().__init__(**kwargs)
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_depth = self.d_model // self.n_heads

        if self.d_model % self.n_heads != 0:
            raise ValueError("number of heads must divide d_model")

        # Learnable linear maps that project the inputs to queries, keys and values.
        self.wq = tf.keras.layers.Dense(self.d_model, use_bias=False)
        self.wk = tf.keras.layers.Dense(self.d_model, use_bias=False)
        self.wv = tf.keras.layers.Dense(self.d_model, use_bias=False)

        # Output projection applied after the heads are concatenated back together.
        self.w_out = tf.keras.layers.Dense(self.d_model, use_bias=False)

    def split_heads(self, tensor, batch_size):
        """Reshapes a tensor so that attention can be computed over all heads in parallel.

        Splits the last dimension into (n_heads, head_depth), then transposes the result to
        (batch_size, n_heads, seq_len, head_depth) so that matmul broadcasts over the heads.
        """
        tensor = tf.reshape(tensor, (batch_size, -1, self.n_heads, self.head_depth))
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(self, q, k, v, mask=None):
        batch_size = tf.shape(q)[0]

        # Linear projections of the inputs: (batch_size, seq_len, d_model)
        Q = self.wq(q)
        K = self.wk(k)
        V = self.wv(v)

        # Split into heads: (batch_size, n_heads, seq_len, head_depth)
        Q = self.split_heads(Q, batch_size)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)

        # Compatibility scores (attention logits): (batch_size, n_heads, seq_len_q, seq_len_k)
        compatibility = tf.matmul(Q, K, transpose_b=True)

        # Scale by sqrt(head_depth) so the logits stay in a range where softmax has useful gradients.
        dk = tf.cast(tf.shape(K)[-1], tf.float32)
        compatibility = compatibility / tf.math.sqrt(dk)

        if mask is not None:
            # Broadcast the mask over the heads dimension:
            # (batch_size, seq_len_q, seq_len_k) -> (batch_size, 1, seq_len_q, seq_len_k)
            mask = mask[:, tf.newaxis, :, :]

            # Masked positions get -inf so that softmax assigns them zero weight.
            compatibility = tf.where(mask,
                                     tf.ones_like(compatibility) * (-np.inf),
                                     compatibility)

        compatibility = tf.nn.softmax(compatibility, axis=-1)

        # If every key position in a row is masked, the softmax over a row of -inf yields NaNs;
        # replace them with zeros so that fully-masked queries simply produce zero outputs.
        compatibility = tf.where(tf.math.is_nan(compatibility), tf.zeros_like(compatibility), compatibility)

        # Weighted sum of the values: (batch_size, n_heads, seq_len_q, head_depth)
        attention = tf.matmul(compatibility, V)

        # Re-join the heads: (batch_size, seq_len_q, n_heads, head_depth) -> (batch_size, seq_len_q, d_model)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        attention = tf.reshape(attention, (batch_size, -1, self.d_model))

        # Final linear projection mixes information across the heads.
        output = self.w_out(attention)

        return output
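

# Minimal usage sketch (illustrative only): the layer sizes, tensor shapes and the self-attention
# setup below are arbitrary assumptions, not values prescribed by this module.
if __name__ == "__main__":
    mha = MultiHeadAttention(n_heads=8, d_model=128)

    batch_size, seq_len = 2, 10
    x = tf.random.uniform((batch_size, seq_len, 128))

    # Self-attention: queries, keys and values all come from the same tensor.
    # A boolean mask of shape (batch_size, seq_len, seq_len), with True at the positions
    # that must not be attended to, could be passed instead of None.
    out = mha(x, x, x, mask=None)
    print(out.shape)  # (2, 10, 128)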