# Spaces:
# Build error
# Build error
import os
import threading

import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.layers import (
    Input,
    Dense,
    Conv2D,
    Conv2DTranspose,
    Flatten,
    Reshape,
    Lambda,
    BatchNormalization,
)
from keras.models import Model

KL = tf.keras.layers
def cbam_layer(inputs_tensor=None, ratio=None, share_weights=False):
    """Apply CBAM-style channel attention to a feature map.

    Source: https://blog.csdn.net/ZXF_1991/article/details/104615942

    The input is summarised by global average pooling and by global max
    pooling.  Each pooled descriptor is passed through a two-layer 1x1-conv
    bottleneck (``channels // ratio`` filters, ReLU, ``channels`` filters);
    the two results are added, squashed with a sigmoid and multiplied back
    onto the input.  Only the channel-attention part is implemented here.

    Parameters
    ----------
    inputs_tensor : Keras tensor
        Feature map whose last axis holds the channels.  The channel count
        must be statically known.
    ratio : int
        Reduction factor of the bottleneck.
    share_weights : bool, optional
        If True, the same bottleneck layers process both pooled descriptors
        (one shared MLP).  The default (False) keeps the historical
        behaviour of this function, which builds two independent
        bottlenecks, so existing models keep their parameter layout.

    Returns
    -------
    Keras tensor
        ``inputs_tensor`` rescaled per channel by the attention gate.

    Raises
    ------
    ValueError
        If ``inputs_tensor`` or ``ratio`` is missing, ``ratio`` is not
        positive, or the channel dimension is unknown.
    """
    if inputs_tensor is None:
        raise ValueError("cbam_layer requires `inputs_tensor`.")
    if ratio is None or ratio <= 0:
        raise ValueError("cbam_layer requires a positive `ratio`.")

    channels = K.int_shape(inputs_tensor)[-1]
    if channels is None:
        raise ValueError("cbam_layer needs a statically known channel dimension.")

    # Guard against a zero-filter convolution when channels < ratio.
    bottleneck = max(1, channels // ratio)

    def build_mlp():
        """Create one squeeze/ReLU/excite stack and return a callable applying it."""
        squeeze = KL.Conv2D(bottleneck, (1, 1), strides=1, padding="valid")
        relu = KL.Activation('relu')
        excite = KL.Conv2D(channels, (1, 1), strides=1, padding="valid")

        def apply(tensor):
            return excite(relu(squeeze(tensor)))

        return apply

    avg_descriptor = KL.GlobalAveragePooling2D()(inputs_tensor)
    avg_descriptor = KL.Reshape((1, 1, channels))(avg_descriptor)
    max_descriptor = KL.GlobalMaxPool2D()(inputs_tensor)
    max_descriptor = KL.Reshape((1, 1, channels))(max_descriptor)

    avg_mlp = build_mlp()
    avg_descriptor = avg_mlp(avg_descriptor)
    # Reusing `avg_mlp` shares weights; building a second stack does not.
    max_mlp = avg_mlp if share_weights else build_mlp()
    max_descriptor = max_mlp(max_descriptor)

    gate = KL.Add()([avg_descriptor, max_descriptor])
    gate = KL.Activation('sigmoid')(gate)
    return KL.multiply([inputs_tensor, gate])
def res_cell(x, n_channel=64, stride=1):
    """Build one residual cell, the basic unit of the VAE.

    The main path is conv -> batch norm -> ELU -> conv -> batch norm ->
    channel attention (``cbam_layer`` with ratio 8).  It is added to a skip
    path and passed through a final ELU.

    Parameters
    ----------
    x : Keras tensor
        Input feature map (channels last).
    n_channel : int
        Number of filters of every convolution in the cell.
    stride : int
        Selects the cell type:
        ``-1`` upsampling cell: transposed 5x5 convs, the first with stride
        2; the skip path is 2x bilinear upsampling followed by a 1x1 conv.
        ``2`` downsampling cell: 5x5 convs, the first with stride 2; the
        skip path is a 1x1 conv with stride 2.
        Any other value: shape-preserving cell (intended for ``1``); the
        first 5x5 conv uses ``stride`` and the skip path is the input.

    Returns
    -------
    Keras tensor
        Output feature map with ``n_channel`` channels.
    """
    if stride == -1:
        # Upsampling cell.
        skip = KL.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
        skip = Conv2D(filters=n_channel, kernel_size=(1, 1), strides=1, padding='same')(skip)
        first_conv = Conv2DTranspose(filters=n_channel, kernel_size=(5, 5), strides=2, padding='same')
        second_conv = Conv2DTranspose(filters=n_channel, kernel_size=(5, 5), padding='same')
    elif stride == 2:
        # Downsampling cell.
        skip = Conv2D(filters=n_channel, kernel_size=(1, 1), strides=2, padding='same')(x)
        first_conv = Conv2D(filters=n_channel, kernel_size=(5, 5), strides=stride, padding='same')
        second_conv = Conv2D(filters=n_channel, kernel_size=(5, 5), padding='same')
    else:
        # Preserving cell: the input itself is the shortcut (no identity op
        # is needed on a symbolic tensor).
        skip = x
        if K.int_shape(x)[-1] != n_channel:
            # The residual add needs matching channel counts; project the
            # shortcut only when the input width differs from n_channel.
            skip = Conv2D(filters=n_channel, kernel_size=(1, 1), strides=1, padding='same')(skip)
        first_conv = Conv2D(filters=n_channel, kernel_size=(5, 5), strides=stride, padding='same')
        second_conv = Conv2D(filters=n_channel, kernel_size=(5, 5), padding='same')

    # Main path, identical for all three cell types apart from the conv layers.
    x = first_conv(x)
    x = BatchNormalization()(x)
    x = KL.Activation('elu')(x)
    x = second_conv(x)
    x = BatchNormalization()(x)
    x = cbam_layer(inputs_tensor=x, ratio=8)

    # Residual connection followed by the output non-linearity, expressed as
    # Keras layers like the rest of the graph.
    x = KL.Add()([x, skip])
    return KL.Activation('elu')(x)
def res_block(x, n_channel=64, upsample=False, n_cells=2):
    """Stack ``n_cells`` residual cells into one block.

    The first cell resamples the input: an upsampling cell (``stride=-1``)
    when ``upsample`` is true, otherwise a downsampling cell (``stride=2``).
    It is followed by ``n_cells - 1`` cells with ``stride=1``.  All cells
    use ``n_channel`` filters.  Returns the output tensor of the last cell.
    """
    resample_stride = -1 if upsample else 2
    out = res_cell(x, n_channel=n_channel, stride=resample_stride)

    # Remaining cells keep the resolution produced by the first one.
    for _ in range(1, n_cells):
        out = res_cell(out, n_channel=n_channel, stride=1)
    return out
def l1_distance(x1, x2):
    """Return the mean absolute difference of ``x1`` and ``x2`` over all elements."""
    abs_diff = tf.math.abs(x1 - x2)
    return tf.reduce_mean(abs_diff)
def l1_log_distance(x1, x2):
    """Return the mean absolute difference of the logarithms of ``x1`` and ``x2``.

    Both inputs are clamped from below at 1e-6 before the logarithm is
    taken, so zero or negative entries do not produce -inf/NaN.
    """
    floor = 1e-6
    log_x1 = tf.math.log(tf.maximum(floor, x1))
    log_x2 = tf.math.log(tf.maximum(floor, x2))
    return tf.reduce_mean(tf.math.abs(log_x1 - log_x2))
# Size of one input image (height, width, channels); used as the encoder's
# Input shape.
img_height, img_width, num_channels = 512, 256, 1
input_shape = (img_height, img_width, num_channels)

# Dimensionality of the latent vector produced by the encoder.
timbre_dim = 20

# Not referenced by the code visible in this file section.
n_filters = 64

# Activation of the decoder's first Dense layer.
act = 'elu'
def compute_latent(x):
    """Draw a latent sample with the re-parameterization trick.

    ``x`` is a pair ``(mu, sigma)`` of 2-D tensors.  ``sigma`` is used as a
    log-variance: the sample is ``mu + exp(sigma / 2) * eps`` with ``eps``
    drawn from a standard normal of the same (batch, dim) shape.
    """
    mean, log_var = x
    n_rows = K.shape(mean)[0]      # dynamic batch size
    n_dims = K.int_shape(mean)[1]  # static latent width
    noise = K.random_normal(shape=(n_rows, n_dims))
    std = K.exp(log_var / 2)
    return mean + std * noise
def get_encoder(N2=0, channel_sizes=None):
    """Assemble and return the VAE encoder.

    The encoder applies one downsampling ``res_block`` per entry of
    ``channel_sizes`` (the first with a single cell, the others with
    ``1 + N2`` cells), flattens the result and maps it with two Dense layers
    to ``mu_timbre`` and ``sigma_timbre`` (each of width ``timbre_dim``).

    The returned model maps an image of shape ``input_shape`` to
    ``[latent_vector, kl_loss]``: a sample drawn by ``compute_latent`` and
    the KL term, summed over latent dimensions and averaged over the batch.
    """
    if channel_sizes is None:
        channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]

    image_in = Input(shape=input_shape)
    features = res_block(image_in, channel_sizes[0], upsample=False, n_cells=1)
    cells_per_block = 1 + N2
    for width in channel_sizes[1:]:
        features = res_block(features, width, upsample=False, n_cells=cells_per_block)

    flat = Flatten()(features)
    mu_timbre = Dense(timbre_dim)(flat)
    sigma_timbre = Dense(timbre_dim)(flat)
    latent_vector = Lambda(compute_latent, output_shape=(timbre_dim,))([mu_timbre, sigma_timbre])

    # NOTE(review): the KL term is built with raw TF ops on symbolic tensors;
    # confirm the installed Keras version accepts that.
    per_dim_kl = -0.5 * (1 + sigma_timbre - tf.square(mu_timbre) - tf.exp(sigma_timbre))
    kl_loss = tf.reduce_mean(tf.reduce_sum(per_dim_kl, axis=1))

    return Model(image_in, [latent_vector, kl_loss])
def get_decoder(N2=0, N3=8, channel_sizes=None):
    """Assemble and return the VAE decoder.

    A Dense layer expands the latent vector (width ``timbre_dim``) to a seed
    feature map of size ``(img_height / 2**N3, img_width / 2**N3,
    channel_sizes[-1])``.  The seed is then upsampled by one ``res_block``
    per remaining entry of ``channel_sizes`` (taken in reverse order, each
    with ``1 + N2`` cells) and by a final stride-2 transposed convolution
    with sigmoid activation that outputs ``num_channels`` channels.

    In total the seed is upsampled ``len(channel_sizes)`` times by a factor
    of two, so the output matches ``input_shape`` when
    ``N3 == len(channel_sizes)``.

    Parameters
    ----------
    N2 : int
        Number of extra cells per residual block.
    N3 : int
        Number of halvings used to size the seed feature map.
    channel_sizes : sequence of int, optional
        Encoder channel widths; the decoder walks them in reverse.

    Raises
    ------
    ValueError
        If ``N3`` is negative or the image size is not divisible by
        ``2**N3`` (the seed map would not have an integral size).
    """
    if channel_sizes is None:
        channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]

    # Derive the seed size from the image constants instead of hard-coding
    # their base-2 logarithms, so the decoder follows `input_shape`.
    if N3 < 0:
        raise ValueError("N3 must be non-negative.")
    scale = 2 ** N3
    if img_height % scale or img_width % scale:
        raise ValueError(
            f"Image size ({img_height}, {img_width}) is not divisible by 2**N3 = {scale}."
        )
    seed_height = img_height // scale
    seed_width = img_width // scale
    seed_channels = channel_sizes[-1]

    decoder_input = Input(shape=(timbre_dim,))
    decoder = Dense(seed_height * seed_width * seed_channels, activation=act)(decoder_input)
    decoder_conv = Reshape((seed_height, seed_width, seed_channels))(decoder)

    # All widths except the last (already used for the seed), deepest first.
    for c in reversed(channel_sizes[:-1]):
        decoder_conv = res_block(decoder_conv, c, upsample=True, n_cells=1 + N2)

    decoder_conv = Conv2DTranspose(filters=num_channels, kernel_size=5, strides=2,
                                   padding='same', activation='sigmoid')(decoder_conv)
    return Model(decoder_input, decoder_conv)
def VAE(N2=0, N3=8, channel_sizes=None):
    """Assemble the encoder, the decoder and the combined VAE model.

    The combined model takes ``[image, scalar]`` inputs (the scalar input
    has shape ``(1,)`` and is not connected to any output) and produces
    ``[reconstruction, kl_loss]``.  A summary of the combined model is
    printed.

    Requires ``N2 >= 0``, ``1 <= N3 <= 8`` and ``N3 == len(channel_sizes)``
    (checked with assertions).

    Returns the tuple ``(encoder, decoder, vae)``.
    """
    if channel_sizes is None:
        channel_sizes = [32, 64, 64, 96, 96, 128, 160, 216]

    print("Creating model...")
    assert N2 >= 0, "Please set N2 >= 0"
    assert 1 <= N3 <= 8, "Please set 1 <= N3 <= 8"
    assert N3 == len(channel_sizes), "Please set N3 = len(channel_sizes)"

    encoder = get_encoder(N2, channel_sizes)
    decoder = get_decoder(N2, N3, channel_sizes)
    # Previously saved sub-models could be loaded here instead, e.g.
    # tf.keras.models.load_model("encoder_thesis_record_1.h5") and
    # tf.keras.models.load_model("decoder_thesis_record_1.h5").

    image_in = Input(shape=input_shape)
    scalar_in = Input(shape=(1,))
    timbre_code, kl_term = encoder(image_in)
    reconstruction = decoder(timbre_code)

    vae_model = Model([image_in, scalar_in], [reconstruction, kl_term])
    # decoder.summary() can be enabled for a per-layer view of the decoder.
    vae_model.summary()
    return encoder, decoder, vae_model
def my_thread(data_cache):
    """Thread target: call ``data_cache.refresh()`` and return nothing."""
    refresh = data_cache.refresh
    refresh()
def _make_vae_losses(threshold, kl_weight, batch_size):
    """Build the (reconstruction_loss, kl_loss) pair for one training stage."""

    def weighted_binary_cross_entropy_loss(true, pred):
        # Binary cross-entropy per element; both logs are clamped at 1e-20.
        b_n = true * tf.math.log(tf.maximum(1e-20, pred)) + (1 - true) * tf.math.log(tf.maximum(1e-20, 1 - pred))
        # Each element is divided by max(threshold, true).
        w = tf.maximum(threshold, true)
        # Summed over all elements and normalised by the nominal batch size.
        return -tf.reduce_sum(b_n / w) / batch_size

    def reconstruction_loss(true, pred):
        return K.mean(weighted_binary_cross_entropy_loss(K.flatten(true), K.flatten(pred)))

    def kl_loss(true, pred):
        # The model's second output already is the KL term; the target is ignored.
        return pred * kl_weight

    return reconstruction_loss, kl_loss


def train_VAE(vae, encoder, decoder, data_cache, stages, batch_size):
    """Train the VAE stage by stage.

    For every stage the VAE is re-compiled with an Adam optimizer and the
    stage's loss settings, fitted on ``data_cache.get_all_data()`` while
    ``data_cache.refresh()`` runs in a background thread, and the encoder
    and decoder are then saved to ``./models/new_trained_models/``.

    Parameters
    ----------
    vae: keras.engine.functional.Functional
        The VAE.
    encoder: keras.engine.functional.Functional
        The VAE encoder.
    decoder: keras.engine.functional.Functional
        The VAE decoder.
    data_cache: Data_cache
        A Data_cache entity that provides training data through
        ``get_all_data()`` and is refreshed through ``refresh()``.
    stages: iterable of dict
        Defines the training stages.  Each stage must provide the keys
        ``"threshold"``, ``"kl_weight"``, ``"learning_rate"`` and
        ``"n_epoch"``.
    batch_size: int
        Batch size passed to ``fit``; also used to normalise the
        reconstruction loss.

    Returns
    -------
    list
        The object returned by ``vae.fit`` for each stage, in order.
    """
    save_dir = "./models/new_trained_models"
    histories = []

    for stage in stages:
        # Bind this stage's weights explicitly instead of relying on
        # closures that observe later rebinding of outer variables.
        reconstruction_loss, kl_loss = _make_vae_losses(
            stage["threshold"], stage["kl_weight"], batch_size
        )
        vae.compile(tf.keras.optimizers.Adam(learning_rate=stage["learning_rate"]),
                    loss=[reconstruction_loss, kl_loss])

        Input_all = data_cache.get_all_data()
        n_total = np.shape(Input_all)[0]

        # Refresh the cache in the background while this stage trains on the
        # data fetched above.
        refresher = threading.Thread(target=my_thread, args=(data_cache,))
        refresher.start()
        try:
            history = vae.fit([Input_all, np.ones(n_total)], [Input_all, np.ones(n_total)],
                              epochs=stage["n_epoch"], batch_size=batch_size)
        finally:
            # Always wait for the refresh, even if training failed.
            refresher.join()
        histories.append(history)

        # Make sure the target directory exists before saving.
        os.makedirs(save_dir, exist_ok=True)
        encoder.save("./models/new_trained_models/encoder.h5")
        decoder.save("./models/new_trained_models/decoder.h5")

    return histories