import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
import numpy as np
import tensorflow_probability as tfp

output_dim = 256  # width of the hidden layers in each coupling subnetwork
reg = 0.01  # L2 regularization coefficient for all Dense layers


def Coupling(input_shape):
    """Builds the two subnetworks of one affine coupling layer: `s`
    (per-dimension log-scale, tanh output for stability) and `t`
    (per-dimension translation, linear output)."""
    input = keras.layers.Input(shape=input_shape)

    t_layer_1 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(input)
    t_layer_2 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_1)
    t_layer_3 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_2)
    t_layer_4 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_3)
    t_layer_5 = keras.layers.Dense(
        input_shape, activation="linear", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_4)

    s_layer_1 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(input)
    s_layer_2 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_1)
    s_layer_3 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_2)
    s_layer_4 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_3)
    s_layer_5 = keras.layers.Dense(
        input_shape, activation="tanh", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_4)

    return keras.Model(inputs=input, outputs=[s_layer_5, t_layer_5])
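

# Quick sanity check of the coupling networks (an illustrative sketch; the
# underscore-prefixed names are ours, not part of the original example).
# Both outputs have the same shape as the input: a per-dimension log-scale
# and a per-dimension translation.
_coupling = Coupling(2)
_s, _t = _coupling(tf.zeros((4, 2)))
assert _s.shape == (4, 2) and _t.shape == (4, 2)

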
class RealNVP(keras.Model):
    def __init__(self, num_coupling_layers):
        super().__init__()

        self.num_coupling_layers = num_coupling_layers

        # Base distribution of the latent space: a standard 2D Gaussian.
        self.distribution = tfp.distributions.MultivariateNormalDiag(
            loc=[0.0, 0.0], scale_diag=[1.0, 1.0]
        )
        # Alternating binary masks over the two input dimensions.
        self.masks = np.array(
            [[0, 1], [1, 0]] * (num_coupling_layers // 2), dtype="float32"
        )
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.layers_list = [Coupling(2) for i in range(num_coupling_layers)]
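
    # For num_coupling_layers=6 the masks come out as:
    #   [[0, 1], [1, 0], [0, 1], [1, 0], [0, 1], [1, 0]]
    # so consecutive coupling layers condition on opposite halves of the
    # input and every dimension gets transformed.
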
    @property
    def metrics(self):
        """List of the model's metrics.

        We make sure the loss tracker is listed as part of `model.metrics`
        so that `fit()` and `evaluate()` are able to reset the loss tracker
        at the start of each epoch and at the start of an `evaluate()` call.
        """
        return [self.loss_tracker]
    def call(self, x, training=True):
        log_det_inv = 0
        # direction = -1 runs the coupling layers in reverse order and
        # inverts each affine transform (data -> latent); direction = 1
        # runs them forward (latent -> data).
        direction = 1
        if training:
            direction = -1
        for i in range(self.num_coupling_layers)[::direction]:
            x_masked = x * self.masks[i]
            reversed_mask = 1 - self.masks[i]
            s, t = self.layers_list[i](x_masked)
            s *= reversed_mask
            t *= reversed_mask
            # gate is 0 in the forward direction and -1 in the inverse one,
            # switching the update between y = x * exp(s) + t and
            # x = (y - t) * exp(-s) on the unmasked dimensions.
            gate = (direction - 1) / 2
            x = (
                reversed_mask
                * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))
                + x_masked
            )
            log_det_inv += gate * tf.reduce_sum(s, [1])

        return x, log_det_inv
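
    # Invertibility sanity check (a sketch; `model` and `x` here are
    # hypothetical names, not defined in this module):
    #
    #   model = RealNVP(num_coupling_layers=6)
    #   x = tf.random.normal((4, 2))
    #   z, _ = model(x)                      # data -> latent (training=True)
    #   x_rec, _ = model(z, training=False)  # latent -> data
    #   # x_rec should match x up to numerical error.
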
    def log_loss(self, x):
        y, logdet = self(x)
        log_likelihood = self.distribution.log_prob(y) + logdet
        return -tf.reduce_mean(log_likelihood)
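
    # This is the change-of-variables formula: with f mapping data to the
    # latent space, log p_X(x) = log p_Z(f(x)) + log |det J_f(x)|, where
    # `logdet` is the log-Jacobian-determinant accumulated across the
    # coupling layers in `call`. The loss is the mean negative
    # log-likelihood over the batch.
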
    def train_step(self, data):
        with tf.GradientTape() as tape:
            loss = self.log_loss(data)

        g = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(g, self.trainable_variables))
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}
    def test_step(self, data):
        loss = self.log_loss(data)
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}
def load_model():
    return RealNVP(num_coupling_layers=6)
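

# Example usage (a minimal sketch: the two-moons dataset, the normalization,
# and the optimizer/epoch settings below are assumptions in the spirit of
# this example, not requirements of the model above).
if __name__ == "__main__":
    from sklearn.datasets import make_moons

    data = make_moons(3000, noise=0.05)[0].astype("float32")
    norm_data = (data - data.mean(axis=0)) / data.std(axis=0)

    model = load_model()
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001))
    model.fit(norm_data, batch_size=256, epochs=300, validation_split=0.2)

    # Data -> latent space (the default training=True is the inverse map).
    z, _ = model(norm_data)
    # Latent space -> data: sample the base Gaussian and push it forward.
    samples = model.distribution.sample(3000)
    x_generated, _ = model.predict(samples)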