cvrp-model / layers.py
peterkros's picture
Upload 7 files
d6c2e08
from __future__ import print_function
import tensorflow as tf
import numpy as np
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention (used by encoder and decoder).

    Args:
        n_heads: number of attention heads computed in parallel. Must be a
            positive divisor of ``d_model``.
        d_model: embedding size of the output features.

    Call arguments:
        q: query, shape (batch_size, seq_len_q, depth_q)
        k: key, shape (batch_size, seq_len_k, depth_k)
        v: value, shape (batch_size, seq_len_v, depth_v)
        mask: optional rank-3 tensor broadcastable to
            (batch_size, seq_len_q, seq_len_k). Boolean, or numeric where
            non-zero is treated as True. True entries are masked out, i.e.
            the corresponding keys receive zero attention weight.

    Since this is scaled dot-product attention, seq_len_k == seq_len_v is
    assumed.

    Returns:
        Attention outputs of shape (batch_size, seq_len_q, d_model).

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, n_heads, d_model, **kwargs):
        super().__init__(**kwargs)

        # Validate before dividing so that a bad head count is reported as a
        # ValueError instead of a ZeroDivisionError (n_heads == 0) or as an
        # obscure reshape failure later on (n_heads < 0).
        if n_heads <= 0:
            raise ValueError("number of heads must be positive")
        if d_model % n_heads != 0:
            raise ValueError("number of heads must divide d_model")

        self.n_heads = n_heads
        self.d_model = d_model
        self.head_depth = self.d_model // self.n_heads

        # Projection matrices (bias-free linear maps).
        self.wq = tf.keras.layers.Dense(self.d_model, use_bias=False)  # (d_q, d_model)
        self.wk = tf.keras.layers.Dense(self.d_model, use_bias=False)  # (d_k, d_model)
        self.wv = tf.keras.layers.Dense(self.d_model, use_bias=False)  # (d_v, d_model)
        self.w_out = tf.keras.layers.Dense(self.d_model, use_bias=False)  # (d_model, d_model)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"n_heads": self.n_heads, "d_model": self.d_model})
        return config

    def split_heads(self, tensor, batch_size):
        """Split the last dimension into (n_heads, head_depth).

        The result is transposed to (batch_size, n_heads, seq_len, head_depth)
        so that attention for all heads is computed with one batched matmul.
        """
        tensor = tf.reshape(tensor, (batch_size, -1, self.n_heads, self.head_depth))
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    # Keras treats the first parameter q as the layer input and k, v as
    # additional call arguments, so input_shape == q.shape.
    def call(self, q, k, v, mask=None):
        """Compute multi-head attention of queries ``q`` over ``k`` / ``v``."""
        # shape of q: (batch_size, seq_len_q, d_q)
        batch_size = tf.shape(q)[0]

        # Linear projections: (batch_size, seq_len_*, d_*) -> (batch_size, seq_len_*, d_model),
        # then split into heads: -> (batch_size, n_heads, seq_len_*, head_depth)
        Q = self.split_heads(self.wq(q), batch_size)
        K = self.split_heads(self.wk(k), batch_size)
        V = self.split_heads(self.wv(v), batch_size)

        # Similarity between queries and keys (self-similarity for self-attention).
        #   seq_len_q = n_nodes for encoder self-attention
        #   seq_len_q = 1 for decoder context-vector attention
        #   seq_len_k = n_nodes for both encoder & decoder
        compatibility = tf.matmul(Q, K, transpose_b=True)  # (batch_size, n_heads, seq_len_q, seq_len_k)

        # Rescale by sqrt(head_depth). The scale is cast to the dtype of the
        # scores so the division also works for non-float32 compute dtypes.
        scale = tf.math.sqrt(tf.cast(self.head_depth, compatibility.dtype))
        compatibility = compatibility / scale

        if mask is not None:
            # tf.where needs a boolean condition; the cast is a no-op for
            # boolean masks and lets numeric (0/1) masks work as well.
            # Insert a head axis so the mask broadcasts over all heads:
            # (batch_size, seq_len_q, seq_len_k) -> (batch_size, 1, seq_len_q, seq_len_k)
            mask = tf.cast(mask, tf.bool)[:, tf.newaxis, :, :]

            # tf.where is used instead of adding mask * -inf,
            # because 0 * -inf evaluates to NaN rather than 0.
            compatibility = tf.where(
                mask,
                tf.ones_like(compatibility) * (-np.inf),
                compatibility,
            )

        compatibility = tf.nn.softmax(compatibility, axis=-1)  # (batch_size, n_heads, seq_len_q, seq_len_k)

        # Fully masked rows are all -inf, for which softmax yields NaN;
        # replace those by zeros so such queries simply attend to nothing.
        compatibility = tf.where(
            tf.math.is_nan(compatibility),
            tf.zeros_like(compatibility),
            compatibility,
        )

        # seq_len_k == seq_len_v
        attention = tf.matmul(compatibility, V)  # (batch_size, n_heads, seq_len_q, head_depth)

        # Move heads next to head_depth and merge them (concatenate heads):
        # -> (batch_size, seq_len_q, n_heads, head_depth) -> (batch_size, seq_len_q, d_model)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        attention = tf.reshape(attention, (batch_size, -1, self.d_model))

        # Output projection. This is equivalent to projecting each head with
        # its own W_o block and summing (as written in the article), because
        # of block-matrix multiplication, e.g.
        # https://math.stackexchange.com/questions/2961550/matrix-block-multiplication-definition-properties-and-applications
        output = self.w_out(attention)  # (batch_size, seq_len_q, d_model)

        return output