import tensorflow as tf import numpy as np import scipy.signal as sps import scipy.special as spspec import tensorflow.keras.backend as K import math from tensorflow.python.keras.utils import conv_utils from tensorflow.keras.layers import Layer, InputSpec, Conv2D, LeakyReLU, Dense, BatchNormalization, Input, Concatenate from tensorflow.keras.layers import Conv2DTranspose, ReLU, Activation, UpSampling2D, Add, Reshape, Multiply from tensorflow.keras.layers import AveragePooling2D, LayerNormalization, GlobalAveragePooling2D, MaxPooling2D, Flatten from tensorflow.keras import initializers, constraints, regularizers from tensorflow.keras.models import Model from tensorflow_addons.layers import InstanceNormalization def sin_activation(x, omega=30): return tf.math.sin(omega * x) class AdaIN(Layer): def __init__(self, **kwargs): super(AdaIN, self).__init__(**kwargs) def build(self, input_shapes): x_shape = input_shapes[0] w_shape = input_shapes[1] self.w_channels = w_shape[-1] self.x_channels = x_shape[-1] self.dense_1 = Dense(self.x_channels) self.dense_2 = Dense(self.x_channels) def call(self, inputs): x, w = inputs ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels)) yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels)) return ys * x + yb def get_config(self): config = { #'w_channels': self.w_channels, #'x_channels': self.x_channels } base_config = super(AdaIN, self).get_config() return dict(list(base_config.items()) + list(config.items())) class Conv2DMod(Layer): def __init__(self, filters, kernel_size, strides=1, padding='valid', dilation_rate=1, kernel_initializer='glorot_uniform', kernel_regularizer=None, activity_regularizer=None, kernel_constraint=None, demod=True, **kwargs): super(Conv2DMod, self).__init__(**kwargs) self.filters = filters self.rank = 2 self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size') self.strides = conv_utils.normalize_tuple(strides, 2, 'strides') self.padding = conv_utils.normalize_padding(padding) self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate') self.kernel_initializer = initializers.get(kernel_initializer) self.kernel_regularizer = regularizers.get(kernel_regularizer) self.activity_regularizer = regularizers.get(activity_regularizer) self.kernel_constraint = constraints.get(kernel_constraint) self.demod = demod self.input_spec = [InputSpec(ndim = 4), InputSpec(ndim = 2)] def build(self, input_shape): channel_axis = -1 if input_shape[0][channel_axis] is None: raise ValueError('The channel dimension of the inputs ' 'should be defined. Found `None`.') input_dim = input_shape[0][channel_axis] kernel_shape = self.kernel_size + (input_dim, self.filters) if input_shape[1][-1] != input_dim: raise ValueError('The last dimension of modulation input should be equal to input dimension.') self.kernel = self.add_weight(shape=kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) # Set input spec. 
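# (The spec below pins the channel dimension discovered above.) In call(), the shared kernel is scaled per sample by the style input ("modulation") and, when demod=True, renormalized ("demodulation") as in StyleGAN2; the batch is then folded into the channel axis so a single grouped convolution applies every per-sample kernel at once.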
self.input_spec = [InputSpec(ndim=4, axes={channel_axis: input_dim}), InputSpec(ndim=2)] self.built = True def call(self, inputs): #To channels last x = tf.transpose(inputs[0], [0, 3, 1, 2]) #Get weight and bias modulations #Make sure w's shape is compatible with self.kernel w = K.expand_dims(K.expand_dims(K.expand_dims(inputs[1], axis = 1), axis = 1), axis = -1) #Add minibatch layer to weights wo = K.expand_dims(self.kernel, axis = 0) #Modulate weights = wo * (w+1) #Demodulate if self.demod: d = K.sqrt(K.sum(K.square(weights), axis=[1,2,3], keepdims = True) + 1e-8) weights = weights / d #Reshape/scale input x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]]) # Fused => reshape minibatch to convolution groups. w = tf.reshape(tf.transpose(weights, [1, 2, 3, 0, 4]), [weights.shape[1], weights.shape[2], weights.shape[3], -1]) x = tf.nn.conv2d(x, w, strides=self.strides, padding="SAME", data_format="NCHW") # Reshape/scale output. x = tf.reshape(x, [-1, self.filters, x.shape[2], x.shape[3]]) # Fused => reshape convolution groups back to minibatch. x = tf.transpose(x, [0, 2, 3, 1]) return x def compute_output_shape(self, input_shape): space = input_shape[0][1:-1] new_space = [] for i in range(len(space)): new_dim = conv_utils.conv_output_length( space[i], self.kernel_size[i], padding=self.padding, stride=self.strides[i], dilation=self.dilation_rate[i]) new_space.append(new_dim) return (input_shape[0],) + tuple(new_space) + (self.filters,) def get_config(self): config = { 'filters': self.filters, 'kernel_size': self.kernel_size, 'strides': self.strides, 'padding': self.padding, 'dilation_rate': self.dilation_rate, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'demod': self.demod } base_config = super(Conv2DMod, self).get_config() return dict(list(base_config.items()) + list(config.items())) class CreatePatches(tf.keras.layers.Layer ): def __init__( self , patch_size): super( CreatePatches , self).__init__() self.patch_size = patch_size def call(self, inputs): patches = [] # For square images only ( as inputs.shape[ 1 ] = inputs.shape[ 2 ] ) input_image_size = inputs.shape[ 1 ] for i in range( 0 , input_image_size , self.patch_size ): for j in range( 0 , input_image_size , self.patch_size ): patches.append( inputs[ : , i : i + self.patch_size , j : j + self.patch_size , : ] ) return patches def get_config(self): config = {'patch_size': self.patch_size, } base_config = super(CreatePatches, self).get_config() return dict(list(base_config.items()) + list(config.items())) class SelfAttention(tf.keras.layers.Layer ): def __init__( self , alpha, filters=128): super(SelfAttention , self).__init__() self.alpha = alpha self.filters = filters self.f = Conv2D(filters, 1, 1) self.g = Conv2D(filters, 1, 1) self.s = Conv2D(filters, 1, 1) def call(self, inputs): f_map = self.f(inputs) f_map = tf.image.transpose(f_map) g_map = self.g(inputs) s_map = self.s(inputs) att = f_map * g_map att = att / self.alpha return tf.keras.activations.softmax(att + s_map, axis=0) def get_config(self): config = {'alpha': self.alpha, } base_config = super(SelfAttention, self).get_config() return dict(list(base_config.items()) + list(config.items())) class Sampling(Layer): """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.""" def call(self, inputs): z_mean, z_log_var = inputs batch 
= tf.shape(z_mean)[0] dim = tf.shape(z_mean)[1] epsilon = tf.keras.backend.random_normal(shape=(batch, dim)) return z_mean + tf.exp(0.5 * z_log_var) * epsilon def get_config(self): config = { } base_config = super(Sampling, self).get_config() return dict(list(base_config.items()) + list(config.items())) class ArcMarginPenaltyLogists(tf.keras.layers.Layer): """ArcMarginPenaltyLogists""" def __init__(self, num_classes, margin=0.7, logist_scale=64, **kwargs): super(ArcMarginPenaltyLogists, self).__init__(**kwargs) self.num_classes = num_classes self.margin = margin self.logist_scale = logist_scale def build(self, input_shape): self.w = self.add_weight("weights", shape=[int(input_shape[-1]), self.num_classes]) self.cos_m = tf.identity(math.cos(self.margin), name='cos_m') self.sin_m = tf.identity(math.sin(self.margin), name='sin_m') self.th = tf.identity(math.cos(math.pi - self.margin), name='th') self.mm = tf.multiply(self.sin_m, self.margin, name='mm') def call(self, embds, labels): normed_embds = tf.nn.l2_normalize(embds, axis=1, name='normed_embd') normed_w = tf.nn.l2_normalize(self.w, axis=0, name='normed_weights') cos_t = tf.matmul(normed_embds, normed_w, name='cos_t') sin_t = tf.sqrt(1. - cos_t ** 2, name='sin_t') cos_mt = tf.subtract( cos_t * self.cos_m, sin_t * self.sin_m, name='cos_mt') cos_mt = tf.where(cos_t > self.th, cos_mt, cos_t - self.mm) mask = tf.one_hot(tf.cast(labels, tf.int32), depth=self.num_classes, name='one_hot_mask') logists = tf.where(mask == 1., cos_mt, cos_t) logists = tf.multiply(logists, self.logist_scale, 'arcface_logist') return logists def get_config(self): config = {'num_classes': self.num_classes, 'margin': self.margin, 'logist_scale': self.logist_scale } base_config = super(ArcMarginPenaltyLogists, self).get_config() return dict(list(base_config.items()) + list(config.items())) class KLLossLayer(tf.keras.layers.Layer): """Computes the VAE KL-divergence term and exposes it as the 'kl_loss' metric.""" def __init__(self, beta=1.5, **kwargs): super(KLLossLayer, self).__init__(**kwargs) self.beta = beta def call(self, inputs): z_mean, z_log_var = inputs kl_loss = tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) kl_loss = -0.5 * kl_loss * self.beta self.add_loss(kl_loss * 0) self.add_metric(kl_loss, 'kl_loss') return inputs def get_config(self): config = { 'beta': self.beta } base_config = super(KLLossLayer, self).get_config() return dict(list(base_config.items()) + list(config.items())) class ReflectionPadding2D(Layer): def __init__(self, padding=(1, 1), **kwargs): self.padding = tuple(padding) self.input_spec = [InputSpec(ndim=4)] super(ReflectionPadding2D, self).__init__(**kwargs) def compute_output_shape(self, s): return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3]) def call(self, x, mask=None): w_pad,h_pad = self.padding return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT') def get_config(self): config = { 'padding': self.padding, } base_config = super(ReflectionPadding2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) class ResBlock(Layer): def __init__(self, fil, **kwargs): super(ResBlock, self).__init__(**kwargs) self.fil = fil self.conv_0 = Conv2D(kernel_size=3, filters=fil, strides=1) self.conv_1 = Conv2D(kernel_size=3, filters=fil, strides=1) self.res = Conv2D(kernel_size=1, filters=1, strides=1) self.lrelu = LeakyReLU(0.2) self.padding = ReflectionPadding2D(padding=(1, 1)) def call(self, inputs): res = self.res(inputs) x = self.padding(inputs) x = self.conv_0(x) x = self.lrelu(x) x = self.padding(x) x
= self.conv_1(x) x = self.lrelu(x) out = x + res return out def get_config(self): config = { 'fil': self.fil } base_config = super(ResBlock, self).get_config() return dict(list(base_config.items()) + list(config.items())) class SubpixelConv2D(Layer): """ Subpixel Conv2D Layer upsampling a layer from (h, w, c) to (h*r, w*r, c/(r*r)), where r is the scaling factor, default to 4 # Arguments upsampling_factor: the scaling factor # Input shape Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model. # Output shape the second and the third dimension increased by a factor of `upsampling_factor`; the last layer decreased by a factor of `upsampling_factor^2`. # References Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network Shi et Al. https://arxiv.org/abs/1609.05158 """ def __init__(self, upsampling_factor=4, **kwargs): super(SubpixelConv2D, self).__init__(**kwargs) self.upsampling_factor = upsampling_factor def build(self, input_shape): last_dim = input_shape[-1] factor = self.upsampling_factor * self.upsampling_factor if last_dim % (factor) != 0: raise ValueError('Channel ' + str(last_dim) + ' should be of ' 'integer times of upsampling_factor^2: ' + str(factor) + '.') def call(self, inputs, **kwargs): return tf.nn.depth_to_space( inputs, self.upsampling_factor ) def get_config(self): config = { 'upsampling_factor': self.upsampling_factor, } base_config = super(SubpixelConv2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) def compute_output_shape(self, input_shape): factor = self.upsampling_factor * self.upsampling_factor input_shape_1 = None if input_shape[1] is not None: input_shape_1 = input_shape[1] * self.upsampling_factor input_shape_2 = None if input_shape[2] is not None: input_shape_2 = input_shape[2] * self.upsampling_factor dims = [ input_shape[0], input_shape_1, input_shape_2, int(input_shape[3]/factor) ] return tuple( dims ) def id_mod_res(inputs, c): feature_map, z_id = inputs x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(feature_map) x = AdaIN()([x, z_id]) x = ReLU()(x) x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) x = AdaIN()([x, z_id]) out = Add()([x, feature_map]) return out def id_mod_res_v2(inputs, c): feature_map, z_id = inputs affine = Dense(feature_map.shape[-1])(z_id) x = Conv2DMod(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))([feature_map, affine]) x = ReLU()(x) affine = Dense(x.shape[-1])(z_id) x = Conv2DMod(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))([x, affine]) out = Add()([x, feature_map]) x = ReLU()(x) return out def simswap(im_size, filter_scale=1, deep=True): inputs = Input(shape=(im_size, im_size, 3)) z_id = Input(shape=(512,)) x = ReflectionPadding2D(padding=(3, 3))(inputs) x = Conv2D(filters=64 // filter_scale, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 112 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = Conv2D(filters=64 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 56 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = Conv2D(filters=256 // filter_scale, kernel_size=3, strides=2, padding='same', 
kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 28 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 14 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) # 14 if deep: x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 7 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = id_mod_res([x, z_id], 512 // filter_scale) x = id_mod_res([x, z_id], 512 // filter_scale) x = id_mod_res([x, z_id], 512 // filter_scale) x = id_mod_res([x, z_id], 512 // filter_scale) if deep: x = SubpixelConv2D(upsampling_factor=2)(x) x = Conv2D(filters=512 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = SubpixelConv2D(upsampling_factor=2)(x) x = Conv2D(filters=256 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = SubpixelConv2D(upsampling_factor=2)(x) x = Conv2D(filters=128 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) # 56 x = SubpixelConv2D(upsampling_factor=2)(x) x = Conv2D(filters=64 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) # 112 x = ReflectionPadding2D(padding=(3, 3))(x) out = Conv2D(filters=3, kernel_size=7, padding='valid')(x) # 112 model = Model([inputs, z_id], out) model.summary() return model def simswap_v2(deep=True): inputs = Input(shape=(224, 224, 3)) z_id = Input(shape=(512,)) x = ReflectionPadding2D(padding=(3, 3))(inputs) x = Conv2D(filters=64, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 112 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = Conv2D(filters=64, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 56 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = Conv2D(filters=256, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 28 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 14 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) # 14 if deep: x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 7 x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = id_mod_res_v2([x, z_id], 512) x = id_mod_res_v2([x, z_id], 512) x = id_mod_res_v2([x, z_id], 512) x = id_mod_res_v2([x, z_id], 512) x = id_mod_res_v2([x, z_id], 512) x = id_mod_res_v2([x, z_id], 512) if deep: x = UpSampling2D(interpolation='bilinear')(x) x = Conv2D(filters=512, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = UpSampling2D(interpolation='bilinear')(x) 
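# Decoder: each stage doubles the spatial resolution with bilinear upsampling followed by a 3x3 conv + BatchNorm + ReLU, mirroring the strided-conv encoder above.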
x = Conv2D(filters=256, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) x = UpSampling2D(interpolation='bilinear')(x) x = Conv2D(filters=128, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) # 56 x = UpSampling2D(interpolation='bilinear')(x) x = Conv2D(filters=64, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) x = BatchNormalization()(x) x = Activation(tf.keras.activations.relu)(x) # 112 x = ReflectionPadding2D(padding=(3, 3))(x) out = Conv2D(filters=3, kernel_size=7, padding='valid')(x) # 112 out = Activation('sigmoid')(out) model = Model([inputs, z_id], out) model.summary() return model class AdaptiveAttention(Layer): def __init__(self, **kwargs): super(AdaptiveAttention, self).__init__(**kwargs) def call(self, inputs): m, a, i = inputs return (1 - m) * a + m * i def get_config(self): base_config = super(AdaptiveAttention, self).get_config() return base_config def aad_block(inputs, c_out): h, z_att, z_id = inputs h_norm = BatchNormalization()(h) h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(h_norm) m = Activation('sigmoid')(h) z_att_gamma = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att) z_att_beta = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att) a = Multiply()([h_norm, z_att_gamma]) a = Add()([a, z_att_beta]) z_id_gamma = Dense(h_norm.shape[-1], kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id) z_id_gamma = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_gamma) z_id_beta = Dense(h_norm.shape[-1], kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id) z_id_beta = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_beta) i = Multiply()([h_norm, z_id_gamma]) i = Add()([i, z_id_beta]) h_out = AdaptiveAttention()([m, a, i]) return h_out def aad_block_mod(inputs, c_out): h, z_att, z_id = inputs h_norm = BatchNormalization()(h) h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(h_norm) m = Activation('sigmoid')(h) z_att_gamma = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att) z_att_beta = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att) a = Multiply()([h_norm, z_att_gamma]) a = Add()([a, z_att_beta]) z_id_gamma = Dense(h_norm.shape[-1], kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_id) i = Conv2DMod(filters=c_out, kernel_size=1, padding='same', kernel_initializer='he_uniform', kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))([h_norm, z_id_gamma]) h_out = AdaptiveAttention()([m, a, i]) return h_out def aad_res_block(inputs, c_in, c_out): h, z_att, z_id = inputs if c_in == c_out: aad = aad_block([h, z_att, z_id], c_out) act = ReLU()(aad) conv = Conv2D(filters=c_out, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) aad = aad_block([conv, z_att, z_id], c_out) act = ReLU()(aad) conv = Conv2D(filters=c_out, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) h_out = Add()([h, conv]) return h_out else: aad = aad_block([h, z_att, z_id], c_in) act = ReLU()(aad) h_res = Conv2D(filters=c_out, kernel_size=3, padding='same', 
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) aad = aad_block([h, z_att, z_id], c_in) act = ReLU()(aad) conv = Conv2D(filters=c_out, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) aad = aad_block([conv, z_att, z_id], c_out) act = ReLU()(aad) conv = Conv2D(filters=c_out, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) h_out = Add()([h_res, conv]) return h_out def aad_res_block_mod(inputs, c_in, c_out): h, z_att, z_id = inputs if c_in == c_out: aad = aad_block_mod([h, z_att, z_id], c_out) act = ReLU()(aad) conv = Conv2D(filters=c_out, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) aad = aad_block_mod([conv, z_att, z_id], c_out) act = ReLU()(aad) conv = Conv2D(filters=c_out, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) h_out = Add()([h, conv]) return h_out else: aad = aad_block_mod([h, z_att, z_id], c_in) act = ReLU()(aad) h_res = Conv2D(filters=c_out, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) aad = aad_block_mod([h, z_att, z_id], c_in) act = ReLU()(aad) conv = Conv2D(filters=c_out, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) aad = aad_block_mod([conv, z_att, z_id], c_out) act = ReLU()(aad) conv = Conv2D(filters=c_out, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) h_out = Add()([h_res, conv]) return h_out class FilteredReLU(Layer): def __init__(self, critically_sampled, in_channels, out_channels, in_size, out_size, in_sampling_rate, out_sampling_rate, in_cutoff, out_cutoff, in_half_width, out_half_width, conv_kernel = 3, lrelu_upsampling = 2, filter_size = 6, conv_clamp = 256, use_radial_filters = False, is_torgb = False, **kwargs): super(FilteredReLU, self).__init__(**kwargs) self.critically_sampled = critically_sampled self.in_channels = in_channels self.out_channels = out_channels self.in_size = np.broadcast_to(np.asarray(in_size), [2]) self.out_size = np.broadcast_to(np.asarray(out_size), [2]) self.in_sampling_rate = in_sampling_rate self.out_sampling_rate = out_sampling_rate self.in_cutoff = in_cutoff self.out_cutoff = out_cutoff self.in_half_width = in_half_width self.out_half_width = out_half_width self.is_torgb = is_torgb self.conv_kernel = 1 if is_torgb else conv_kernel self.lrelu_upsampling = lrelu_upsampling self.conv_clamp = conv_clamp self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) # Up sampling filter self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate) # Down sampling filter self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 self.d_radial = use_radial_filters and not self.critically_sampled self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.d_radial) # Compute padding pad_total 
= (self.out_size - 1) * self.d_factor + 1 pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor pad_total += self.u_taps + self.d_taps - 2 pad_lo = (pad_total + self.u_factor) // 2 pad_hi = pad_total - pad_lo self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] self.gain = 1 if self.is_torgb else np.sqrt(2) self.slope = 1 if self.is_torgb else 0.2 self.act_funcs = {'linear': {'func': lambda x, **_: x, 'def_alpha': 0, 'def_gain': 1}, 'lrelu': {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), 'def_alpha': 0.2, 'def_gain': np.sqrt(2)}, } b_init = tf.zeros_initializer() self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), dtype="float32"), trainable=True) def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): if numtaps == 1: return None if not radial: f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) return f x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs r = np.hypot(*np.meshgrid(x, x)) f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) w = np.kaiser(numtaps, beta) f *= np.outer(w, w) f /= np.sum(f) return f def get_filter_size(self, f): if f is None: return 1, 1 assert 1 <= f.ndim <= 2 return f.shape[-1], f.shape[0] # width, height def parse_padding(self, padding): if isinstance(padding, int): padding = [padding, padding] assert isinstance(padding, (list, tuple)) assert all(isinstance(x, (int, np.integer)) for x in padding) padding = [int(x) for x in padding] if len(padding) == 2: px, py = padding padding = [px, px, py, py] px0, px1, py0, py1 = padding return px0, px1, py0, py1 def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): spec = self.act_funcs[act] alpha = float(alpha if alpha is not None else spec['def_alpha']) gain = float(gain if gain is not None else spec['def_gain']) clamp = float(clamp if clamp is not None else -1) if b is not None: x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) x = spec['func'](x, alpha=alpha) if gain != 1: x = x * gain if clamp >= 0: x = tf.clip_by_value(x, -clamp, clamp) return x def parse_scaling(self, scaling): if isinstance(scaling, int): scaling = [scaling, scaling] sx, sy = scaling assert sx >= 1 and sy >= 1 return sx, sy def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): if f is None: f = tf.ones([1, 1], dtype=tf.float32) batch_size, in_height, in_width, num_channels = x.shape upx, upy = self.parse_scaling(up) downx, downy = self.parse_scaling(down) padx0, padx1, pady0, pady1 = self.parse_padding(padding) upW = in_width * upx + padx0 + padx1 upH = in_height * upy + pady0 + pady1 assert upW >= f.shape[-1] and upH >= f.shape[0] # Channel first format. x = tf.transpose(x, perm=[0, 3, 1, 2]) # Upsample by inserting zeros. x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1]) x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. x = tf.pad(x, [[0, 0], [0, 0], [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] # Setup filter. 
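# The 1-D low-pass kernel below is gain-scaled, flipped (unless flip_filter is set) so the op is a true convolution rather than a correlation, and applied separably (a width pass, then a height pass) as a per-channel grouped convolution.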
f = f * (gain ** (f.ndim / 2)) f = tf.cast(f, dtype=x.dtype) if not flip_filter: f = tf.reverse(f, axis=[-1]) f = tf.reshape(f, shape=(1, 1, f.shape[-1])) f = tf.repeat(f, repeats=num_channels, axis=0) if tf.rank(f) == 4: f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) x = tf.nn.conv2d(x, f_0, 1, 'VALID') else: f_0 = tf.expand_dims(f, axis=2) f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) f_1 = tf.expand_dims(f, axis=3) f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') x = x[:, :, ::downy, ::downx] # Back to channel last. x = tf.transpose(x, perm=[0, 2, 3, 1]) return x def filtered_lrelu(self, x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): #fu_w, fu_h = self.get_filter_size(fu) #fd_w, fd_h = self.get_filter_size(fd) px0, px1, py0, py1 = self.parse_padding(padding) #batch_size, in_h, in_w, channels = x.shape #in_dtype = x.dtype #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down x = self.bias_act(x=x, b=b) x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) return x def call(self, inputs): return self.filtered_lrelu(inputs, fu=self.u_filter, fd=self.d_filter, b=self.bias, up=self.u_factor, down=self.d_factor, padding=self.padding, gain=self.gain, slope=self.slope, clamp=self.conv_clamp) def get_config(self): base_config = super(FilteredReLU, self).get_config() return base_config class SynthesisLayer(Layer): def __init__(self, critically_sampled, in_channels, out_channels, in_size, out_size, in_sampling_rate, out_sampling_rate, in_cutoff, out_cutoff, in_half_width, out_half_width, conv_kernel = 3, lrelu_upsampling = 2, filter_size = 6, conv_clamp = 256, use_radial_filters = False, is_torgb = False, **kwargs): super(SynthesisLayer, self).__init__(**kwargs) self.critically_sampled = critically_sampled self.in_channels = in_channels self.out_channels = out_channels self.in_size = np.broadcast_to(np.asarray(in_size), [2]) self.out_size = np.broadcast_to(np.asarray(out_size), [2]) self.in_sampling_rate = in_sampling_rate self.out_sampling_rate = out_sampling_rate self.in_cutoff = in_cutoff self.out_cutoff = out_cutoff self.in_half_width = in_half_width self.out_half_width = out_half_width self.is_torgb = is_torgb self.conv_kernel = 1 if is_torgb else conv_kernel self.lrelu_upsampling = lrelu_upsampling self.conv_clamp = conv_clamp self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) # Up sampling filter self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate) # Down sampling filter self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 self.d_radial = use_radial_filters and not 
self.critically_sampled self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.d_radial) # Compute padding pad_total = (self.out_size - 1) * self.d_factor + 1 pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor pad_total += self.u_taps + self.d_taps - 2 pad_lo = (pad_total + self.u_factor) // 2 pad_hi = pad_total - pad_lo self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] self.gain = 1 if self.is_torgb else np.sqrt(2) self.slope = 1 if self.is_torgb else 0.2 self.act_funcs = {'linear': {'func': lambda x, **_: x, 'def_alpha': 0, 'def_gain': 1}, 'lrelu': {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), 'def_alpha': 0.2, 'def_gain': np.sqrt(2)}, } b_init = tf.zeros_initializer() self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), dtype="float32"), trainable=True) self.affine = Dense(self.in_channels) self.conv = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same') def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): if numtaps == 1: return None if not radial: f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) return f x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs r = np.hypot(*np.meshgrid(x, x)) f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) w = np.kaiser(numtaps, beta) f *= np.outer(w, w) f /= np.sum(f) return f def get_filter_size(self, f): if f is None: return 1, 1 assert 1 <= f.ndim <= 2 return f.shape[-1], f.shape[0] # width, height def parse_padding(self, padding): if isinstance(padding, int): padding = [padding, padding] assert isinstance(padding, (list, tuple)) assert all(isinstance(x, (int, np.integer)) for x in padding) padding = [int(x) for x in padding] if len(padding) == 2: px, py = padding padding = [px, px, py, py] px0, px1, py0, py1 = padding return px0, px1, py0, py1 def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): spec = self.act_funcs[act] alpha = float(alpha if alpha is not None else spec['def_alpha']) gain = float(gain if gain is not None else spec['def_gain']) clamp = float(clamp if clamp is not None else -1) if b is not None: x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) x = spec['func'](x, alpha=alpha) if gain != 1: x = x * gain if clamp >= 0: x = tf.clip_by_value(x, -clamp, clamp) return x def parse_scaling(self, scaling): if isinstance(scaling, int): scaling = [scaling, scaling] sx, sy = scaling assert sx >= 1 and sy >= 1 return sx, sy def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): if f is None: f = tf.ones([1, 1], dtype=tf.float32) batch_size, in_height, in_width, num_channels = x.shape upx, upy = self.parse_scaling(up) downx, downy = self.parse_scaling(down) padx0, padx1, pady0, pady1 = self.parse_padding(padding) upW = in_width * upx + padx0 + padx1 upH = in_height * upy + pady0 + pady1 assert upW >= f.shape[-1] and upH >= f.shape[0] # Channel first format. x = tf.transpose(x, perm=[0, 3, 1, 2]) # Upsample by inserting zeros. x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1]) x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. 
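# Positive padding is added with tf.pad; negative padding (cropping) is handled by the slice that follows it.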
x = tf.pad(x, [[0, 0], [0, 0], [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] # Setup filter. f = f * (gain ** (f.ndim / 2)) f = tf.cast(f, dtype=x.dtype) if not flip_filter: f = tf.reverse(f, axis=[-1]) f = tf.reshape(f, shape=(1, 1, f.shape[-1])) f = tf.repeat(f, repeats=num_channels, axis=0) if tf.rank(f) == 4: f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) x = tf.nn.conv2d(x, f_0, 1, 'VALID') else: f_0 = tf.expand_dims(f, axis=2) f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) f_1 = tf.expand_dims(f, axis=3) f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') x = x[:, :, ::downy, ::downx] # Back to channel last. x = tf.transpose(x, perm=[0, 2, 3, 1]) return x def filtered_lrelu(self, x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): #fu_w, fu_h = self.get_filter_size(fu) #fd_w, fd_h = self.get_filter_size(fd) px0, px1, py0, py1 = self.parse_padding(padding) #batch_size, in_h, in_w, channels = x.shape #in_dtype = x.dtype #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down x = self.bias_act(x=x, b=b) x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) return x def call(self, inputs): x, w = inputs styles = self.affine(w) x = self.conv([x, styles]) x = self.filtered_lrelu(x, fu=self.u_filter, fd=self.d_filter, b=self.bias, up=self.u_factor, down=self.d_factor, padding=self.padding, gain=self.gain, slope=self.slope, clamp=self.conv_clamp) return x def get_config(self): base_config = super(SynthesisLayer, self).get_config() return base_config class SynthesisLayerNoMod(Layer): def __init__(self, critically_sampled, in_channels, out_channels, in_size, out_size, in_sampling_rate, out_sampling_rate, in_cutoff, out_cutoff, in_half_width, out_half_width, conv_kernel = 3, lrelu_upsampling = 2, filter_size = 6, conv_clamp = 256, use_radial_filters = False, is_torgb = False, batch_size = 10, **kwargs): super(SynthesisLayerNoMod, self).__init__(**kwargs) self.critically_sampled = critically_sampled self.bs = batch_size self.in_channels = in_channels self.out_channels = out_channels self.in_size = np.broadcast_to(np.asarray(in_size), [2]) self.out_size = np.broadcast_to(np.asarray(out_size), [2]) self.in_sampling_rate = in_sampling_rate self.out_sampling_rate = out_sampling_rate self.in_cutoff = in_cutoff self.out_cutoff = out_cutoff self.in_half_width = in_half_width self.out_half_width = out_half_width self.is_torgb = is_torgb self.conv_kernel = 1 if is_torgb else conv_kernel self.lrelu_upsampling = lrelu_upsampling self.conv_clamp = conv_clamp self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) # Up sampling filter self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 self.u_filter = 
self.design_lowpass_filter(numtaps=self.u_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate) # Down sampling filter self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 self.d_radial = use_radial_filters and not self.critically_sampled self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.d_radial) # Compute padding pad_total = (self.out_size - 1) * self.d_factor + 1 pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor pad_total += self.u_taps + self.d_taps - 2 pad_lo = (pad_total + self.u_factor) // 2 pad_hi = pad_total - pad_lo self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] self.gain = 1 if self.is_torgb else np.sqrt(2) self.slope = 1 if self.is_torgb else 0.2 self.act_funcs = {'linear': {'func': lambda x, **_: x, 'def_alpha': 0, 'def_gain': 1}, 'lrelu': {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), 'def_alpha': 0.2, 'def_gain': np.sqrt(2)}, } b_init = tf.zeros_initializer() self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), dtype="float32"), trainable=True) self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): if numtaps == 1: return None if not radial: f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) return f x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs r = np.hypot(*np.meshgrid(x, x)) f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) w = np.kaiser(numtaps, beta) f *= np.outer(w, w) f /= np.sum(f) return f def get_filter_size(self, f): if f is None: return 1, 1 assert 1 <= f.ndim <= 2 return f.shape[-1], f.shape[0] # width, height def parse_padding(self, padding): if isinstance(padding, int): padding = [padding, padding] assert isinstance(padding, (list, tuple)) assert all(isinstance(x, (int, np.integer)) for x in padding) padding = [int(x) for x in padding] if len(padding) == 2: px, py = padding padding = [px, px, py, py] px0, px1, py0, py1 = padding return px0, px1, py0, py1 @tf.function def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): spec = self.act_funcs[act] alpha = float(alpha if alpha is not None else spec['def_alpha']) gain = float(gain if gain is not None else spec['def_gain']) clamp = float(clamp if clamp is not None else -1) if b is not None: x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) x = spec['func'](x, alpha=alpha) if gain != 1: x = x * gain if clamp >= 0: x = tf.clip_by_value(x, -clamp, clamp) return x def parse_scaling(self, scaling): if isinstance(scaling, int): scaling = [scaling, scaling] sx, sy = scaling assert sx >= 1 and sy >= 1 return sx, sy @tf.function def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): if f is None: f = tf.ones([1, 1], dtype=tf.float32) batch_size, in_height, in_width, num_channels = x.shape batch_size = tf.shape(x)[0] upx, upy = self.parse_scaling(up) downx, downy = self.parse_scaling(down) padx0, padx1, pady0, pady1 = self.parse_padding(padding) upW = in_width * upx + padx0 + padx1 upH = in_height * upy + pady0 + pady1 assert upW >= f.shape[-1] and 
upH >= f.shape[0] # Channel first format. x = tf.transpose(x, perm=[0, 3, 1, 2]) # Upsample by inserting zeros. x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. x = tf.pad(x, [[0, 0], [0, 0], [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] # Setup filter. f = f * (gain ** (tf.rank(f) / 2)) f = tf.cast(f, dtype=x.dtype) if not flip_filter: f = tf.reverse(f, axis=[-1]) f = tf.reshape(f, shape=(1, 1, f.shape[-1])) f = tf.repeat(f, repeats=num_channels, axis=0) #if tf.rank(f) == 500: # f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) # x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') #else: f_0 = tf.expand_dims(f, axis=2) f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) f_1 = tf.expand_dims(f, axis=3) f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') x = x[:, :, ::downy, ::downx] # Back to channel last. x = tf.transpose(x, perm=[0, 2, 3, 1]) return x @tf.function def filtered_lrelu(self, x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): #fu_w, fu_h = self.get_filter_size(fu) #fd_w, fd_h = self.get_filter_size(fd) px0, px1, py0, py1 = self.parse_padding(padding) #batch_size, in_h, in_w, channels = x.shape #in_dtype = x.dtype #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down x = self.bias_act(x=x, b=b) x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) return x def call(self, inputs): x = inputs x = self.conv(x) x = self.filtered_lrelu(x, fu=self.u_filter, fd=self.d_filter, b=self.bias, up=self.u_factor, down=self.d_factor, padding=self.padding, gain=self.gain, slope=self.slope, clamp=self.conv_clamp) return x def get_config(self): base_config = super(SynthesisLayerNoMod, self).get_config() return base_config class SynthesisLayerNoModBN(Layer): def __init__(self, critically_sampled, in_channels, out_channels, in_size, out_size, in_sampling_rate, out_sampling_rate, in_cutoff, out_cutoff, in_half_width, out_half_width, conv_kernel = 3, lrelu_upsampling = 2, filter_size = 6, conv_clamp = 256, use_radial_filters = False, is_torgb = False, batch_size = 10, **kwargs): super(SynthesisLayerNoModBN, self).__init__(**kwargs) self.critically_sampled = critically_sampled self.bs = batch_size self.in_channels = in_channels self.out_channels = out_channels self.in_size = np.broadcast_to(np.asarray(in_size), [2]) self.out_size = np.broadcast_to(np.asarray(out_size), [2]) self.in_sampling_rate = in_sampling_rate self.out_sampling_rate = out_sampling_rate self.in_cutoff = in_cutoff self.out_cutoff = out_cutoff self.in_half_width = in_half_width self.out_half_width = out_half_width self.is_torgb = is_torgb self.conv_kernel = 1 if is_torgb else conv_kernel self.lrelu_upsampling = lrelu_upsampling self.conv_clamp = conv_clamp self.tmp_sampling_rate =
max(in_sampling_rate, out_sampling_rate) * (1.0 if is_torgb else lrelu_upsampling) # Up sampling filter self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate) # Down sampling filter self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 self.d_radial = use_radial_filters and not self.critically_sampled self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.d_radial) # Compute padding pad_total = (self.out_size - 1) * self.d_factor + 1 pad_total -= (self.in_size) * self.u_factor pad_total += self.u_taps + self.d_taps - 2 pad_lo = (pad_total + self.u_factor) // 2 pad_hi = pad_total - pad_lo self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] self.gain = 1 if self.is_torgb else np.sqrt(2) self.slope = 1 if self.is_torgb else 0.2 self.act_funcs = {'linear': {'func': lambda x, **_: x, 'def_alpha': 0, 'def_gain': 1}, 'lrelu': {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), 'def_alpha': 0.2, 'def_gain': np.sqrt(2)}, } b_init = tf.zeros_initializer() self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), dtype="float32"), trainable=True) self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') self.bn = BatchNormalization() def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): if numtaps == 1: return None if not radial: f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) return f x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs r = np.hypot(*np.meshgrid(x, x)) f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) w = np.kaiser(numtaps, beta) f *= np.outer(w, w) f /= np.sum(f) return f def get_filter_size(self, f): if f is None: return 1, 1 assert 1 <= f.ndim <= 2 return f.shape[-1], f.shape[0] # width, height def parse_padding(self, padding): if isinstance(padding, int): padding = [padding, padding] assert isinstance(padding, (list, tuple)) assert all(isinstance(x, (int, np.integer)) for x in padding) padding = [int(x) for x in padding] if len(padding) == 2: px, py = padding padding = [px, px, py, py] px0, px1, py0, py1 = padding return px0, px1, py0, py1 @tf.function def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): spec = self.act_funcs[act] alpha = float(alpha if alpha is not None else spec['def_alpha']) gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32) clamp = float(clamp if clamp is not None else -1) if b is not None: x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) x = spec['func'](x, alpha=alpha) x = x * gain if clamp >= 0: x = tf.clip_by_value(x, -clamp, clamp) return x def parse_scaling(self, scaling): if isinstance(scaling, int): scaling = [scaling, scaling] sx, sy = scaling assert sx >= 1 and sy >= 1 return sx, sy @tf.function def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): if f 
is None: f = tf.ones([1, 1], dtype=tf.float32) batch_size, in_height, in_width, num_channels = x.shape batch_size = tf.shape(x)[0] upx, upy = self.parse_scaling(up) downx, downy = self.parse_scaling(down) padx0, padx1, pady0, pady1 = self.parse_padding(padding) upW = in_width * upx + padx0 + padx1 upH = in_height * upy + pady0 + pady1 assert upW >= f.shape[-1] and upH >= f.shape[0] # Channel first format. x = tf.transpose(x, perm=[0, 3, 1, 2]) # Upsample by inserting zeros. x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. x = tf.pad(x, [[0, 0], [0, 0], [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] # Setup filter. f = f * (gain ** (tf.rank(f) / 2)) f = tf.cast(f, dtype=x.dtype) if not flip_filter: f = tf.reverse(f, axis=[-1]) f = tf.reshape(f, shape=(1, 1, f.shape[-1])) f = tf.repeat(f, repeats=num_channels, axis=0) #if tf.rank(f) == 500: # f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) # x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') #else: f_0 = tf.expand_dims(f, axis=2) f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) f_1 = tf.expand_dims(f, axis=3) f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') x = x[:, :, ::downy, ::downx] # Back to channel last. x = tf.transpose(x, perm=[0, 2, 3, 1]) return x @tf.function def filtered_lrelu(self, x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): #fu_w, fu_h = self.get_filter_size(fu) #fd_w, fd_h = self.get_filter_size(fd) px0, px1, py0, py1 = self.parse_padding(padding) #batch_size, in_h, in_w, channels = x.shape #in_dtype = x.dtype #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down x = self.bias_act(x=x, b=b) x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) return x def call(self, inputs): x = inputs x = self.conv(x) x = self.bn(x) x = self.filtered_lrelu(x, fu=self.u_filter, fd=self.d_filter, b=self.bias, up=self.u_factor, down=self.d_factor, padding=self.padding, gain=self.gain, slope=self.slope, clamp=self.conv_clamp) return x def get_config(self): base_config = super(SynthesisLayerNoModBN, self).get_config() return base_config class SynthesisInput(Layer): def __init__(self, w_dim, channels, size, sampling_rate, bandwidth, **kwargs): super(SynthesisInput, self).__init__(**kwargs) self.w_dim = w_dim self.channels = channels self.size = np.broadcast_to(np.asarray(size), [2]) self.sampling_rate = sampling_rate self.bandwidth = bandwidth # Draw random frequencies from uniform 2D disc. freqs = np.random.normal(size=(int(channels[0]), 2)) radii = np.sqrt(np.sum(np.square(freqs), axis=1, keepdims=True)) freqs /= radii * np.power(np.exp(np.square(radii)), 0.25) freqs *= bandwidth phases = np.random.uniform(size=[int(channels[0])]) - 0.5 # Setup parameters and buffers. 
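# The frequencies above are drawn on a 2-D disc and scaled to `bandwidth`; below, a learned affine maps w to (r_c, r_s, t_x, t_y), a per-sample rotation and translation of these Fourier features (cf. StyleGAN3's SynthesisInput).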
w_init = tf.random_normal_initializer() self.weight = tf.Variable(initial_value=w_init(shape=(self.channels, self.channels), dtype="float32"), trainable=True) self.affine = Dense(4, kernel_initializer='zeros', bias_initializer=tf.constant_initializer([1., 0., 0., 0.])) # bias [1, 0, 0, 0] = identity transform at init, so the rotation norm below is never zero self.transform = tf.eye(3, 3) self.freqs = tf.constant(freqs, dtype=tf.float32) self.phases = tf.constant(phases, dtype=tf.float32) def call(self, w): # Batch dimension transforms = tf.expand_dims(self.transform, axis=0) freqs = tf.expand_dims(self.freqs, axis=0) phases = tf.expand_dims(self.phases, axis=0) # Apply learned transformation. t = self.affine(w) # t = (r_c, r_s, t_x, t_y) t = t / tf.linalg.norm(t[:, :2], axis=1, keepdims=True) # Inverse rotation wrt. resulting image (assembled with tf.stack, since TF tensors do not support item assignment). zeros = tf.zeros_like(t[:, 0]) ones = tf.ones_like(t[:, 0]) m_r = tf.stack([tf.stack([t[:, 0], -t[:, 1], zeros], axis=1), tf.stack([t[:, 1], t[:, 0], zeros], axis=1), tf.stack([zeros, zeros, ones], axis=1)], axis=1) # rows: [r'_c, -r'_s, 0], [r'_s, r'_c, 0], [0, 0, 1] # Inverse translation wrt. resulting image. m_t = tf.stack([tf.stack([ones, zeros, -t[:, 2]], axis=1), tf.stack([zeros, ones, -t[:, 3]], axis=1), tf.stack([zeros, zeros, ones], axis=1)], axis=1) # rows: [1, 0, -t'_x], [0, 1, -t'_y], [0, 0, 1] transforms = m_r @ m_t @ transforms # Transform frequencies. phases = phases + tf.squeeze(freqs @ transforms[:, :2, 2:], axis=2) freqs = freqs @ transforms[:, :2, :2] # Dampen out-of-band frequencies that may occur due to the user-specified transform. amplitudes = tf.clip_by_value(1 - (tf.linalg.norm(freqs, axis=2, keepdims=True) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth), 0, 1) def get_config(self): base_config = super(SynthesisInput, self).get_config() return base_config class SynthesisLayerFS(Layer): def __init__(self, critically_sampled, in_channels, out_channels, in_size, out_size, in_sampling_rate, out_sampling_rate, in_cutoff, out_cutoff, in_half_width, out_half_width, conv_kernel = 3, lrelu_upsampling = 2, filter_size = 6, conv_clamp = 256, use_radial_filters = False, is_torgb = False, **kwargs): super(SynthesisLayerFS, self).__init__(**kwargs) self.critically_sampled = critically_sampled self.in_channels = in_channels self.out_channels = out_channels self.in_size = np.broadcast_to(np.asarray(in_size), [2]) self.out_size = np.broadcast_to(np.asarray(out_size), [2]) self.in_sampling_rate = in_sampling_rate self.out_sampling_rate = out_sampling_rate self.in_cutoff = in_cutoff self.out_cutoff = out_cutoff self.in_half_width = in_half_width self.out_half_width = out_half_width self.is_torgb = is_torgb self.conv_kernel = 1 if is_torgb else conv_kernel self.lrelu_upsampling = lrelu_upsampling self.conv_clamp = conv_clamp self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) # Up sampling filter self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate) # Down sampling filter self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 self.d_radial = use_radial_filters and not self.critically_sampled self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate,
radial=self.d_radial) # Compute padding pad_total = (self.out_size - 1) * self.d_factor + 1 pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor pad_total += self.u_taps + self.d_taps - 2 pad_lo = (pad_total + self.u_factor) // 2 pad_hi = pad_total - pad_lo self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] self.gain = 1 if self.is_torgb else np.sqrt(2) self.slope = 1 if self.is_torgb else 0.2 self.act_funcs = {'linear': {'func': lambda x, **_: x, 'def_alpha': 0, 'def_gain': 1}, 'lrelu': {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), 'def_alpha': 0.2, 'def_gain': np.sqrt(2)}, } b_init = tf.zeros_initializer() self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), dtype="float32"), trainable=True) self.affine = Dense(self.out_channels) self.conv_mod = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same') self.bn = BatchNormalization() self.conv_gamma = Conv2D(self.out_channels, kernel_size=1) self.conv_beta = Conv2D(self.out_channels, kernel_size=1) self.conv_gate = Conv2D(self.out_channels, kernel_size=1) self.conv_final = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): if numtaps == 1: return None if not radial: f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) return f x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs r = np.hypot(*np.meshgrid(x, x)) f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) w = np.kaiser(numtaps, beta) f *= np.outer(w, w) f /= np.sum(f) return f def get_filter_size(self, f): if f is None: return 1, 1 assert 1 <= f.ndim <= 2 return f.shape[-1], f.shape[0] # width, height def parse_padding(self, padding): if isinstance(padding, int): padding = [padding, padding] assert isinstance(padding, (list, tuple)) assert all(isinstance(x, (int, np.integer)) for x in padding) padding = [int(x) for x in padding] if len(padding) == 2: px, py = padding padding = [px, px, py, py] px0, px1, py0, py1 = padding return px0, px1, py0, py1 @tf.function def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): spec = self.act_funcs[act] alpha = float(alpha if alpha is not None else spec['def_alpha']) gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32) clamp = float(clamp if clamp is not None else -1) if b is not None: x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) x = spec['func'](x, alpha=alpha) x = x * gain if clamp >= 0: x = tf.clip_by_value(x, -clamp, clamp) return x def parse_scaling(self, scaling): if isinstance(scaling, int): scaling = [scaling, scaling] sx, sy = scaling assert sx >= 1 and sy >= 1 return sx, sy @tf.function def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): if f is None: f = tf.ones([1, 1], dtype=tf.float32) batch_size, in_height, in_width, num_channels = x.shape batch_size = tf.shape(x)[0] upx, upy = self.parse_scaling(up) downx, downy = self.parse_scaling(down) padx0, padx1, pady0, pady1 = self.parse_padding(padding) upW = in_width * upx + padx0 + padx1 upH = in_height * upy + pady0 + pady1 assert upW >= f.shape[-1] and upH >= f.shape[0] # Channel first format. x = tf.transpose(x, perm=[0, 3, 1, 2]) # Upsample by inserting zeros. 
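# Zero-stuffing raises the sampling rate; the FIR low-pass applied afterwards removes the resulting spectral images (the caller passes gain=up**2 to compensate for the inserted zeros).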
        # Upsample by inserting zeros.
        x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1])
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]])
        x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx])

        # Pad or crop.
        x = tf.pad(x, [[0, 0], [0, 0],
                       [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)],
                       [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)]])
        x = x[:, :,
              tf.math.maximum(-pady0, 0): x.shape[2] - tf.math.maximum(-pady1, 0),
              tf.math.maximum(-padx0, 0): x.shape[3] - tf.math.maximum(-padx1, 0)]

        # Setup filter.
        f = f * (gain ** (f.ndim / 2))
        f = tf.cast(f, dtype=x.dtype)
        if not flip_filter:
            f = tf.reverse(f, axis=[-1])
        f = tf.reshape(f, shape=(1, 1, f.shape[-1]))
        f = tf.repeat(f, repeats=num_channels, axis=0)

        # Separable depthwise filtering: convolve along width, then along height.
        f_0 = tf.expand_dims(f, axis=2)
        f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0])
        f_1 = tf.expand_dims(f, axis=3)
        f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0])
        x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW')
        x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW')

        # Downsample by keeping every down-th sample.
        x = x[:, :, ::downy, ::downx]

        # Back to channel last.
        x = tf.transpose(x, perm=[0, 2, 3, 1])
        return x

    @tf.function
    def filtered_lrelu(self, x, fu=None, fd=None, b=None, up=1, down=1, padding=0,
                       gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
        px0, px1, py0, py1 = self.parse_padding(padding)
        # Bias, upsample, leaky ReLU at the higher sampling rate, then filter and downsample.
        x = self.bias_act(x=x, b=b)
        x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
        x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp)
        x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter)
        return x

    @tf.function
    def aadm(self, x, w, a):
        # Blend a style-modulated path with an attention-conditioned affine path
        # through a learned sigmoid gate.
        w_affine = self.affine(w)
        x_norm = self.bn(x)
        x_id = self.conv_mod([x_norm, w_affine])
        gate = self.conv_gate(x_norm)
        gate = tf.nn.sigmoid(gate)
        x_att_beta = self.conv_beta(a)
        x_att_gamma = self.conv_gamma(a)
        x_att = x_norm * x_att_beta + x_att_gamma
        h = x_id * gate + (1 - gate) * x_att
        return h

    def call(self, inputs):
        x, w, a = inputs
        x = self.conv_final(x)
        x = self.aadm(x, w, a)
        x = self.filtered_lrelu(x, fu=self.u_filter, fd=self.d_filter, b=self.bias,
                                up=self.u_factor, down=self.d_factor, padding=self.padding,
                                gain=self.gain, slope=self.slope, clamp=self.conv_clamp)
        return x

    def get_config(self):
        base_config = super(SynthesisLayerFS, self).get_config()
        return base_config

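# Illustrative sketch (not part of the original code): one way a SynthesisLayerFS
# could be wired up. The hyperparameters below (sizes, sampling rates, cutoffs,
# half-widths, latent width) are assumptions chosen only so that the construction
# is self-consistent. The grouped NCHW convolutions used internally generally
# require a GPU, and with 'same' convolutions the output spatial size follows
# from the filter taps and padding rather than matching out_size exactly.
def _example_synthesis_layer_fs():
    layer = SynthesisLayerFS(
        critically_sampled=False, in_channels=64, out_channels=64,
        in_size=16, out_size=16, in_sampling_rate=16, out_sampling_rate=16,
        in_cutoff=8, out_cutoff=8, in_half_width=2, out_half_width=2)
    x = tf.random.normal([2, 16, 16, 64])   # feature map
    w = tf.random.normal([2, 512])          # latent / style vector
    a = tf.random.normal([2, 16, 16, 64])   # attention / guidance features
    return layer([x, w, a])
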
class SynthesisLayerUpDownOnly(Layer):
    def __init__(self, critically_sampled, in_channels, out_channels, in_size, out_size,
                 in_sampling_rate, out_sampling_rate, in_cutoff, out_cutoff,
                 in_half_width, out_half_width, conv_kernel=3, lrelu_upsampling=2,
                 filter_size=6, conv_clamp=256, use_radial_filters=False, is_torgb=False, **kwargs):
        super(SynthesisLayerUpDownOnly, self).__init__(**kwargs)
        self.critically_sampled = critically_sampled
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width
        self.is_torgb = is_torgb
        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.lrelu_upsampling = lrelu_upsampling
        self.conv_clamp = conv_clamp
        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)

        # Up sampling filter
        self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate
        self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1
        self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, cutoff=self.in_cutoff,
                                                   width=self.in_half_width * 2, fs=self.tmp_sampling_rate)

        # Down sampling filter
        self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate
        self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1
        self.d_radial = use_radial_filters and not self.critically_sampled
        self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, cutoff=self.out_cutoff,
                                                   width=self.out_half_width * 2, fs=self.tmp_sampling_rate,
                                                   radial=self.d_radial)

        # Compute padding
        pad_total = (self.out_size - 1) * self.d_factor + 1
        pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor
        pad_total += self.u_taps + self.d_taps - 2
        pad_lo = (pad_total + self.u_factor) // 2
        pad_hi = pad_total - pad_lo
        self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]

        self.gain = 1 if self.is_torgb else np.sqrt(2)
        self.slope = 1 if self.is_torgb else 0.2
        self.act_funcs = {
            'linear': {'func': lambda x, **_: x, 'def_alpha': 0, 'def_gain': 1},
            'lrelu': {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), 'def_alpha': 0.2, 'def_gain': np.sqrt(2)},
        }

    def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False):
        if numtaps == 1:
            return None
        if not radial:
            # Separable 1-D Kaiser low-pass filter.
            f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return f
        # Radially symmetric filter windowed with a Kaiser window.
        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        f *= np.outer(w, w)
        f /= np.sum(f)
        return f

    def get_filter_size(self, f):
        if f is None:
            return 1, 1
        assert 1 <= f.ndim <= 2
        return f.shape[-1], f.shape[0]  # width, height

    def parse_padding(self, padding):
        if isinstance(padding, int):
            padding = [padding, padding]
        assert isinstance(padding, (list, tuple))
        assert all(isinstance(x, (int, np.integer)) for x in padding)
        padding = [int(x) for x in padding]
        if len(padding) == 2:
            px, py = padding
            padding = [px, px, py, py]
        px0, px1, py0, py1 = padding
        return px0, px1, py0, py1

    def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None):
        spec = self.act_funcs[act]
        alpha = float(alpha if alpha is not None else spec['def_alpha'])
        gain = float(gain if gain is not None else spec['def_gain'])
        clamp = float(clamp if clamp is not None else -1)
        if b is not None:
            x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))])
        x = spec['func'](x, alpha=alpha)
        if gain != 1:
            x = x * gain
        if clamp >= 0:
            x = tf.clip_by_value(x, -clamp, clamp)
        return x

    def parse_scaling(self, scaling):
        if isinstance(scaling, int):
            scaling = [scaling, scaling]
        sx, sy = scaling
        assert sx >= 1 and sy >= 1
        return sx, sy

    def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
        if f is None:
            f = tf.ones([1, 1], dtype=tf.float32)
        batch_size, in_height, in_width, num_channels = x.shape
        upx, upy = self.parse_scaling(up)
        downx, downy = self.parse_scaling(down)
        padx0, padx1, pady0, pady1 = self.parse_padding(padding)
        upW = in_width * upx + padx0 + padx1
        upH = in_height * upy + pady0 + pady1
        assert upW >= f.shape[-1] and upH >= f.shape[0]

        # Channel first format.
        x = tf.transpose(x, perm=[0, 3, 1, 2])

        # Upsample by inserting zeros.
        x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1])
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]])
        x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx])

        # Pad or crop.
        x = tf.pad(x, [[0, 0], [0, 0],
                       [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)],
                       [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)]])
        x = x[:, :,
              tf.math.maximum(-pady0, 0): x.shape[2] - tf.math.maximum(-pady1, 0),
              tf.math.maximum(-padx0, 0): x.shape[3] - tf.math.maximum(-padx1, 0)]

        # Setup filter.
        f = f * (gain ** (f.ndim / 2))
        f = tf.cast(f, dtype=x.dtype)
        if not flip_filter:
            f = tf.reverse(f, axis=[-1])
        f = tf.reshape(f, shape=(1, 1, f.shape[-1]))
        f = tf.repeat(f, repeats=num_channels, axis=0)
        if tf.rank(f) == 4:
            f_0 = tf.transpose(f, perm=[2, 3, 1, 0])
            x = tf.nn.conv2d(x, f_0, 1, 'VALID')
        else:
            # Separable depthwise filtering: convolve along width, then along height.
            f_0 = tf.expand_dims(f, axis=2)
            f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0])
            f_1 = tf.expand_dims(f, axis=3)
            f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0])
            x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW')
            x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW')

        # Downsample by keeping every down-th sample.
        x = x[:, :, ::downy, ::downx]

        # Back to channel last.
        x = tf.transpose(x, perm=[0, 2, 3, 1])
        return x

    def filtered_lrelu(self, x, fu=None, fd=None, b=None, up=1, down=1, padding=0,
                       gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
        px0, px1, py0, py1 = self.parse_padding(padding)
        # Upsample, leaky ReLU at the higher sampling rate, then filter and downsample.
        x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
        x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp)
        x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter)
        return x

    def call(self, inputs):
        x = inputs
        # This variant has no trainable bias, so none is passed to filtered_lrelu.
        x = self.filtered_lrelu(x, fu=self.u_filter, fd=self.d_filter,
                                up=self.u_factor, down=self.d_factor, padding=self.padding,
                                gain=self.gain, slope=self.slope, clamp=self.conv_clamp)
        return x

    def get_config(self):
        base_config = super(SynthesisLayerUpDownOnly, self).get_config()
        return base_config

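# Illustrative sketch (not part of the original code): using SynthesisLayerUpDownOnly
# as a pure filtered resampling + leaky-ReLU stage, here doubling the sampling rate.
# The hyperparameters are assumptions chosen only so the construction is self-consistent;
# the grouped NCHW convolutions generally require a GPU, and the exact output spatial
# size follows from the filter taps and padding.
def _example_updown_only():
    layer = SynthesisLayerUpDownOnly(
        critically_sampled=False, in_channels=64, out_channels=64,
        in_size=16, out_size=32, in_sampling_rate=16, out_sampling_rate=32,
        in_cutoff=8, out_cutoff=16, in_half_width=2, out_half_width=4)
    x = tf.random.normal([2, 16, 16, 64])
    return layer(x)
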
class Localization(Layer):
    def __init__(self):
        super(Localization, self).__init__()
        self.pool = MaxPooling2D()
        self.conv_0 = Conv2D(36, 5, activation='relu')
        self.conv_1 = Conv2D(36, 5, activation='relu')
        self.flatten = Flatten()
        self.fc_0 = Dense(36, activation='relu')
        # Initialized so the predicted transform starts as the identity.
        self.fc_1 = Dense(6, bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]),
                          kernel_initializer='zeros')
        self.reshape = Reshape((2, 3))

    def compute_output_shape(self, input_shape):
        return [None, 2, 3]

    def call(self, inputs):
        x = self.conv_0(inputs)
        x = self.pool(x)
        x = self.conv_1(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc_0(x)
        theta = self.fc_1(x)
        theta = self.reshape(theta)
        return theta


class BilinearInterpolation(Layer):
    def __init__(self, height=36, width=36):
        super(BilinearInterpolation, self).__init__()
        self.height = height
        self.width = width

    def compute_output_shape(self, input_shape):
        return [None, self.height, self.width, input_shape[0][-1]]

    def get_config(self):
        config = {
            'height': self.height,
            'width': self.width,
        }
        base_config = super(BilinearInterpolation, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def advance_indexing(self, inputs, x, y):
        shape = tf.shape(inputs)
        batch_size = shape[0]
        batch_idx = tf.range(0, batch_size)
        batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
        b = tf.tile(batch_idx, (1, self.height, self.width))
        indices = tf.stack([b, y, x], 3)
        return tf.gather_nd(inputs, indices)

    def grid_generator(self, batch):
        x = tf.linspace(-1.0, 1.0, self.width)
        y = tf.linspace(-1.0, 1.0, self.height)
        xx, yy = tf.meshgrid(x, y)
        xx = tf.reshape(xx, (-1,))
        yy = tf.reshape(yy, (-1,))
        homogenous_coordinates = tf.stack([xx, yy, tf.ones_like(xx)])
        homogenous_coordinates = tf.expand_dims(homogenous_coordinates, axis=0)
        homogenous_coordinates = tf.tile(homogenous_coordinates, [batch, 1, 1])
        homogenous_coordinates = tf.cast(homogenous_coordinates, dtype=tf.float32)
        return homogenous_coordinates

    def interpolate(self, images, homogenous_coordinates, theta):
        with tf.name_scope("Transformation"):
            transformed = tf.matmul(theta, homogenous_coordinates)
            transformed = tf.transpose(transformed, perm=[0, 2, 1])
            transformed = tf.reshape(transformed, [-1, self.height, self.width, 2])
            x_transformed = transformed[:, :, :, 0]
            y_transformed = transformed[:, :, :, 1]
            x = ((x_transformed + 1.) * tf.cast(self.width, dtype=tf.float32)) * 0.5
            y = ((y_transformed + 1.) * tf.cast(self.height, dtype=tf.float32)) * 0.5

        with tf.name_scope("VariableCasting"):
            x0 = tf.cast(tf.math.floor(x), dtype=tf.int32)
            x1 = x0 + 1
            y0 = tf.cast(tf.math.floor(y), dtype=tf.int32)
            y1 = y0 + 1
            x0 = tf.clip_by_value(x0, 0, self.width - 1)
            x1 = tf.clip_by_value(x1, 0, self.width - 1)
            y0 = tf.clip_by_value(y0, 0, self.height - 1)
            y1 = tf.clip_by_value(y1, 0, self.height - 1)
            x = tf.clip_by_value(x, 0, tf.cast(self.width, dtype=tf.float32) - 1.0)
            y = tf.clip_by_value(y, 0, tf.cast(self.height, dtype=tf.float32) - 1.0)

        with tf.name_scope("AdvancedIndexing"):
            i_a = self.advance_indexing(images, x0, y0)
            i_b = self.advance_indexing(images, x0, y1)
            i_c = self.advance_indexing(images, x1, y0)
            i_d = self.advance_indexing(images, x1, y1)

        with tf.name_scope("Interpolation"):
            x0 = tf.cast(x0, dtype=tf.float32)
            x1 = tf.cast(x1, dtype=tf.float32)
            y0 = tf.cast(y0, dtype=tf.float32)
            y1 = tf.cast(y1, dtype=tf.float32)
            w_a = (x1 - x) * (y1 - y)
            w_b = (x1 - x) * (y - y0)
            w_c = (x - x0) * (y1 - y)
            w_d = (x - x0) * (y - y0)
            w_a = tf.expand_dims(w_a, axis=3)
            w_b = tf.expand_dims(w_b, axis=3)
            w_c = tf.expand_dims(w_c, axis=3)
            w_d = tf.expand_dims(w_d, axis=3)
        return w_a * i_a + w_b * i_b + w_c * i_c + w_d * i_d

    def call(self, inputs):
        images, theta = inputs
        homogenous_coordinates = self.grid_generator(batch=tf.shape(images)[0])
        return self.interpolate(images, homogenous_coordinates, theta)


class ResBlockLR(Layer):
    def __init__(self, filters=16):
        super(ResBlockLR, self).__init__()
        self.filters = filters
        self.conv_0 = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')
        self.bn_0 = BatchNormalization()
        self.conv_1 = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')
        self.bn_1 = BatchNormalization()

    def get_config(self):
        config = {
            'filters': self.filters,
        }
        base_config = super(ResBlockLR, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        x = self.conv_0(inputs)
        x = self.bn_0(x)
        x = tf.nn.leaky_relu(x, alpha=0.2)
        x = self.conv_1(x)
        x = self.bn_1(x)
        return x + inputs

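# Illustrative sketch (not part of the original code): a minimal spatial transformer
# assembled from Localization and BilinearInterpolation. The 36x36 single-channel
# input is an assumption matching the default output grid of BilinearInterpolation.
def _example_spatial_transformer():
    images = Input(shape=(36, 36, 1))
    theta = Localization()(images)
    warped = BilinearInterpolation(height=36, width=36)([images, theta])
    return Model(images, warped)
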
class LearnedResize(Layer):
    def __init__(self, width, height, filters=16, in_channels=3, num_res_block=3, interpolation='bilinear'):
        super(LearnedResize, self).__init__()
        self.filters = filters
        self.num_res_block = num_res_block
        self.interpolation = interpolation
        self.in_channels = in_channels
        self.width = width
        self.height = height
        self.resize_layer = tf.keras.layers.experimental.preprocessing.Resizing(height, width,
                                                                                 interpolation=interpolation)
        self.init_layers = tf.keras.models.Sequential([
            Conv2D(filters=filters, kernel_size=7, strides=1, padding='same'),
            LeakyReLU(0.2),
            Conv2D(filters=filters, kernel_size=1, strides=1, padding='same'),
            LeakyReLU(0.2),
            BatchNormalization(),
        ])
        res_blocks = []
        for i in range(num_res_block):
            res_blocks.append(ResBlockLR(filters=filters))
        res_blocks.append(Conv2D(filters=filters, kernel_size=3, strides=1, padding='same', use_bias=False))
        res_blocks.append(BatchNormalization())
        self.res_block_pipe = tf.keras.models.Sequential(res_blocks)
        self.final_conv = Conv2D(filters=in_channels, kernel_size=3, strides=1, padding='same')

    def compute_output_shape(self, input_shape):
        return [None, self.height, self.width, input_shape[-1]]

    def get_config(self):
        config = {
            'filters': self.filters,
            'num_res_block': self.num_res_block,
            'interpolation': self.interpolation,
            'width': self.width,
            'height': self.height,
        }
        base_config = super(LearnedResize, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        # Learned residual branch added on top of a plain bilinear resize.
        x_l = self.init_layers(inputs)
        x_l_0 = self.resize_layer(x_l)
        x_l = self.res_block_pipe(x_l_0)
        x_l = x_l + x_l_0
        x_l = self.final_conv(x_l)
        x = self.resize_layer(inputs)
        return x + x_l
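
# Illustrative sketch (not part of the original code): content-adaptive downscaling
# with LearnedResize. The 128 -> 64 sizes and the channel count are assumptions.
def _example_learned_resize():
    resize = LearnedResize(width=64, height=64, filters=16, in_channels=3)
    x = tf.random.normal([2, 128, 128, 3])
    return resize(x)  # -> [2, 64, 64, 3]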