import math

import numpy as np
import scipy.signal as sps
import scipy.special as spspec

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.python.keras.utils import conv_utils
from tensorflow.keras.layers import Layer, InputSpec, Conv2D, LeakyReLU, Dense, BatchNormalization, Input, Concatenate
from tensorflow.keras.layers import Conv2DTranspose, ReLU, Activation, UpSampling2D, Add, Reshape, Multiply
from tensorflow.keras.layers import AveragePooling2D, LayerNormalization, GlobalAveragePooling2D, MaxPooling2D, Flatten
from tensorflow.keras import initializers, constraints, regularizers
from tensorflow.keras.models import Model
from tensorflow_addons.layers import InstanceNormalization


def sin_activation(x, omega=30):
    """SIREN-style periodic activation: sin(omega * x)."""
    return tf.math.sin(omega * x)


class AdaIN(Layer):
    """Style modulation in the AdaIN spirit: ys * x + yb, where the per-channel
    scale ys and bias yb are predicted from the style vector w.

    Note: this layer applies only the affine modulation; it does not normalize
    its input."""

    def __init__(self, **kwargs):
        super(AdaIN, self).__init__(**kwargs)

    def build(self, input_shapes):
        x_shape = input_shapes[0]
        w_shape = input_shapes[1]

        self.w_channels = w_shape[-1]
        self.x_channels = x_shape[-1]

        self.dense_1 = Dense(self.x_channels)
        self.dense_2 = Dense(self.x_channels)

    def call(self, inputs):
        x, w = inputs
        ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels))
        yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels))
        return ys * x + yb

    def get_config(self):
        return super(AdaIN, self).get_config()
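

# Usage sketch (illustration only, not part of the original module).
# The shapes below are assumptions.
def _demo_adain():
    x = tf.random.normal([4, 32, 32, 64])  # content feature map (NHWC)
    w = tf.random.normal([4, 512])         # style / identity vector
    return AdaIN()([x, w])                 # -> (4, 32, 32, 64)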


class Conv2DMod(Layer):
    """StyleGAN2-style modulated convolution. Takes [feature_map, style],
    where the style vector holds one multiplier per input channel."""

    def __init__(self,
                 filters,
                 kernel_size,
                 strides=1,
                 padding='valid',
                 dilation_rate=1,
                 kernel_initializer='glorot_uniform',
                 kernel_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 demod=True,
                 **kwargs):
        super(Conv2DMod, self).__init__(**kwargs)
        self.filters = filters
        self.rank = 2
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate')
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.demod = demod
        self.input_spec = [InputSpec(ndim=4),
                           InputSpec(ndim=2)]

    def build(self, input_shape):
        channel_axis = -1
        if input_shape[0][channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[0][channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        if input_shape[1][-1] != input_dim:
            raise ValueError('The last dimension of the modulation input must equal the input channel dimension.')

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        self.input_spec = [InputSpec(ndim=4, axes={channel_axis: input_dim}),
                           InputSpec(ndim=2)]
        self.built = True

    def call(self, inputs):
        # To NCHW, so the batch can be folded into the channel axis below.
        x = tf.transpose(inputs[0], [0, 3, 1, 2])

        # Style: (batch, in) -> (batch, 1, 1, in, 1), broadcast over the kernel.
        w = K.expand_dims(K.expand_dims(K.expand_dims(inputs[1], axis=1), axis=1), axis=-1)

        # Kernel: (kh, kw, in, out) -> (1, kh, kw, in, out).
        wo = K.expand_dims(self.kernel, axis=0)

        # Modulate: one scaled kernel per sample.
        weights = wo * (w + 1)

        # Demodulate: normalize each output filter to unit L2 norm per sample.
        if self.demod:
            d = K.sqrt(K.sum(K.square(weights), axis=[1, 2, 3], keepdims=True) + 1e-8)
            weights = weights / d

        # Fold the batch into the channel axis and run one grouped convolution.
        x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]])
        w = tf.reshape(tf.transpose(weights, [1, 2, 3, 0, 4]),
                       [weights.shape[1], weights.shape[2], weights.shape[3], -1])

        x = tf.nn.conv2d(x, w,
                         strides=self.strides,
                         padding=self.padding.upper(),  # was hardcoded to "SAME"; honor the layer's padding argument
                         data_format="NCHW")

        # Unfold the batch and return to NHWC.
        x = tf.reshape(x, [-1, self.filters, x.shape[2], x.shape[3]])
        x = tf.transpose(x, [0, 2, 3, 1])

        return x

    def compute_output_shape(self, input_shape):
        space = input_shape[0][1:-1]
        new_space = []
        for i in range(len(space)):
            new_dim = conv_utils.conv_output_length(
                space[i],
                self.kernel_size[i],
                padding=self.padding,
                stride=self.strides[i],
                dilation=self.dilation_rate[i])
            new_space.append(new_dim)

        # Batch dimension, then spatial dimensions, then filters
        # (was `(input_shape[0],)`, which nested the whole input shape).
        return (input_shape[0][0],) + tuple(new_space) + (self.filters,)

    def get_config(self):
        config = {
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'dilation_rate': self.dilation_rate,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'demod': self.demod
        }
        base_config = super(Conv2DMod, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
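

# Usage sketch (illustration only; shapes are assumptions). The style must
# have one entry per input channel. Note the NCHW grouped convolution inside
# Conv2DMod generally requires a GPU.
def _demo_conv2dmod():
    x = tf.random.normal([4, 32, 32, 64])  # feature map (NHWC)
    s = tf.random.normal([4, 64])          # one scale per input channel
    return Conv2DMod(filters=128, kernel_size=3, padding='same')([x, s])  # -> (4, 32, 32, 128)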


class CreatePatches(tf.keras.layers.Layer):
    """Splits a square image into non-overlapping patch_size x patch_size crops."""

    def __init__(self, patch_size, **kwargs):
        super(CreatePatches, self).__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, inputs):
        patches = []
        input_image_size = inputs.shape[1]
        for i in range(0, input_image_size, self.patch_size):
            for j in range(0, input_image_size, self.patch_size):
                patches.append(inputs[:, i: i + self.patch_size, j: j + self.patch_size, :])
        return patches

    def get_config(self):
        config = {'patch_size': self.patch_size}
        base_config = super(CreatePatches, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
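

# Usage sketch (illustration only; shapes are assumptions).
def _demo_create_patches():
    x = tf.random.normal([2, 32, 32, 3])
    return CreatePatches(patch_size=16)(x)  # list of 4 tensors, each (2, 16, 16, 3)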


class SelfAttention(tf.keras.layers.Layer):
    """Lightweight attention-style gating over 1x1-projected feature maps.

    Note: this is a simplified, elementwise variant. It multiplies the
    spatially transposed query with the key instead of forming a full
    attention matrix, so it assumes square feature maps."""

    def __init__(self, alpha, filters=128, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.alpha = alpha
        self.filters = filters

        self.f = Conv2D(filters, 1, 1)
        self.g = Conv2D(filters, 1, 1)
        self.s = Conv2D(filters, 1, 1)

    def call(self, inputs):
        f_map = self.f(inputs)
        f_map = tf.image.transpose(f_map)  # swap the two spatial axes

        g_map = self.g(inputs)
        s_map = self.s(inputs)

        att = f_map * g_map
        att = att / self.alpha  # temperature

        return tf.keras.activations.softmax(att + s_map, axis=0)

    def get_config(self):
        # `filters` was missing here, which broke round-trip serialization.
        config = {'alpha': self.alpha,
                  'filters': self.filters}
        base_config = super(SelfAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
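

# Usage sketch (illustration only; shapes and alpha are assumptions).
def _demo_self_attention():
    x = tf.random.normal([2, 16, 16, 32])  # must be spatially square
    return SelfAttention(alpha=8.0, filters=128)(x)  # -> (2, 16, 16, 128)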


class Sampling(Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def get_config(self):
        return super(Sampling, self).get_config()
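

# Usage sketch (illustration only): VAE reparameterization on dummy statistics.
def _demo_sampling():
    z_mean = tf.zeros([4, 16])
    z_log_var = tf.zeros([4, 16])
    return Sampling()([z_mean, z_log_var])  # standard-normal draws, shape (4, 16)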


class ArcMarginPenaltyLogists(tf.keras.layers.Layer):
    """ArcFace additive angular margin penalty applied to the logits."""

    def __init__(self, num_classes, margin=0.7, logist_scale=64, **kwargs):
        super(ArcMarginPenaltyLogists, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.margin = margin
        self.logist_scale = logist_scale

    def build(self, input_shape):
        # add_variable is deprecated; add_weight is the supported API.
        self.w = self.add_weight(
            "weights", shape=[int(input_shape[-1]), self.num_classes])
        self.cos_m = tf.identity(math.cos(self.margin), name='cos_m')
        self.sin_m = tf.identity(math.sin(self.margin), name='sin_m')
        self.th = tf.identity(math.cos(math.pi - self.margin), name='th')
        self.mm = tf.multiply(self.sin_m, self.margin, name='mm')

    def call(self, embds, labels):
        normed_embds = tf.nn.l2_normalize(embds, axis=1, name='normed_embd')
        normed_w = tf.nn.l2_normalize(self.w, axis=0, name='normed_weights')

        cos_t = tf.matmul(normed_embds, normed_w, name='cos_t')
        sin_t = tf.sqrt(1. - cos_t ** 2, name='sin_t')

        cos_mt = tf.subtract(
            cos_t * self.cos_m, sin_t * self.sin_m, name='cos_mt')

        # Keep the penalty monotonic once the angle exceeds pi - margin.
        cos_mt = tf.where(cos_t > self.th, cos_mt, cos_t - self.mm)

        mask = tf.one_hot(tf.cast(labels, tf.int32), depth=self.num_classes,
                          name='one_hot_mask')

        logists = tf.where(mask == 1., cos_mt, cos_t)
        logists = tf.multiply(logists, self.logist_scale, 'arcface_logist')

        return logists

    def get_config(self):
        config = {'num_classes': self.num_classes,
                  'margin': self.margin,
                  'logist_scale': self.logist_scale}
        base_config = super(ArcMarginPenaltyLogists, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
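

# Usage sketch (illustration only; sizes are assumptions).
def _demo_arcface_logits():
    layer = ArcMarginPenaltyLogists(num_classes=10)
    embds = tf.random.normal([4, 512])
    labels = tf.constant([0, 3, 7, 9])
    return layer(embds, labels)  # -> (4, 10), scaled by logist_scale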


class KLLossLayer(tf.keras.layers.Layer):
    """Tracks the beta-weighted KL divergence of a VAE latent as a metric.

    Note: the loss is added multiplied by zero, so it is monitored during
    training but contributes no gradient."""

    def __init__(self, beta=1.5, **kwargs):
        super(KLLossLayer, self).__init__(**kwargs)
        self.beta = beta

    def call(self, inputs):
        z_mean, z_log_var = inputs

        kl_loss = tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = -0.5 * kl_loss * self.beta

        self.add_loss(kl_loss * 0)  # monitoring only, no gradient
        self.add_metric(kl_loss, name='kl_loss')

        return inputs

    def get_config(self):
        config = {
            'beta': self.beta
        }
        base_config = super(KLLossLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
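

# Usage sketch (illustration only): KL metric tracking in a tiny functional model.
def _demo_kl_loss():
    z_mean = Input(shape=(8,))
    z_log_var = Input(shape=(8,))
    outs = KLLossLayer(beta=1.5)([z_mean, z_log_var])
    return Model([z_mean, z_log_var], outs)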


class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        # padding is (w_pad, h_pad): height grows by 2*h_pad, width by 2*w_pad
        # (this previously mixed up the two axes relative to call()).
        return (s[0], s[1] + 2 * self.padding[1], s[2] + 2 * self.padding[0], s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

    def get_config(self):
        config = {
            'padding': self.padding,
        }
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class ResBlock(Layer):

    def __init__(self, fil, **kwargs):
        super(ResBlock, self).__init__(**kwargs)
        self.fil = fil

        self.conv_0 = Conv2D(kernel_size=3, filters=fil, strides=1)
        self.conv_1 = Conv2D(kernel_size=3, filters=fil, strides=1)

        # 1x1 shortcut projection; it must produce `fil` channels so the
        # residual addition matches the main branch (was filters=1, which only
        # worked through broadcasting).
        self.res = Conv2D(kernel_size=1, filters=fil, strides=1)

        self.lrelu = LeakyReLU(0.2)
        self.padding = ReflectionPadding2D(padding=(1, 1))

    def call(self, inputs):
        res = self.res(inputs)

        x = self.padding(inputs)
        x = self.conv_0(x)
        x = self.lrelu(x)

        x = self.padding(x)
        x = self.conv_1(x)
        x = self.lrelu(x)

        out = x + res

        return out

    def get_config(self):
        config = {
            'fil': self.fil
        }
        base_config = super(ResBlock, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class SubpixelConv2D(Layer):
    """Subpixel Conv2D (pixel shuffle) layer.

    Upsamples from (h, w, c) to (h*r, w*r, c/(r*r)), where r is the scaling
    factor (default 4).

    # Arguments
        upsampling_factor: the scaling factor r.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        The second and third dimensions are increased by a factor of
        `upsampling_factor`; the last dimension is decreased by a factor of
        `upsampling_factor^2`.
    # References
        Real-Time Single Image and Video Super-Resolution Using an Efficient
        Sub-Pixel Convolutional Neural Network, Shi et al.
        https://arxiv.org/abs/1609.05158
    """

    def __init__(self, upsampling_factor=4, **kwargs):
        super(SubpixelConv2D, self).__init__(**kwargs)
        self.upsampling_factor = upsampling_factor

    def build(self, input_shape):
        last_dim = input_shape[-1]
        factor = self.upsampling_factor * self.upsampling_factor
        if last_dim % factor != 0:
            raise ValueError('Number of channels (' + str(last_dim) + ') must be an '
                             'integer multiple of upsampling_factor^2 (' +
                             str(factor) + ').')

    def call(self, inputs, **kwargs):
        return tf.nn.depth_to_space(inputs, self.upsampling_factor)

    def get_config(self):
        config = {'upsampling_factor': self.upsampling_factor}
        base_config = super(SubpixelConv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        factor = self.upsampling_factor * self.upsampling_factor
        input_shape_1 = None
        if input_shape[1] is not None:
            input_shape_1 = input_shape[1] * self.upsampling_factor
        input_shape_2 = None
        if input_shape[2] is not None:
            input_shape_2 = input_shape[2] * self.upsampling_factor
        dims = [input_shape[0],
                input_shape_1,
                input_shape_2,
                int(input_shape[3] / factor)]
        return tuple(dims)
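

# Usage sketch (illustration only): 2x pixel-shuffle upsampling; the channel
# count must be divisible by upsampling_factor^2.
def _demo_subpixel():
    x = tf.random.normal([1, 8, 8, 32])
    return SubpixelConv2D(upsampling_factor=2)(x)  # -> (1, 16, 16, 8)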


def id_mod_res(inputs, c):
    feature_map, z_id = inputs

    x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(feature_map)
    x = AdaIN()([x, z_id])
    x = ReLU()(x)

    x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)
    x = AdaIN()([x, z_id])

    out = Add()([x, feature_map])

    return out


def id_mod_res_v2(inputs, c):
    feature_map, z_id = inputs

    affine = Dense(feature_map.shape[-1])(z_id)
    x = Conv2DMod(c, kernel_size=3, padding='same',
                  kernel_regularizer=tf.keras.regularizers.l2(0.0001))([feature_map, affine])

    x = ReLU()(x)

    affine = Dense(x.shape[-1])(z_id)
    x = Conv2DMod(c, kernel_size=3, padding='same',
                  kernel_regularizer=tf.keras.regularizers.l2(0.0001))([x, affine])

    out = Add()([x, feature_map])
    # (a stray post-add ReLU whose result was never used has been removed)

    return out
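

# Usage sketch (illustration only; shapes are assumptions). This builds the
# graph of one identity-modulated residual block as a tiny functional model.
def _demo_id_mod_res_v2():
    feat = Input(shape=(28, 28, 512))
    zid = Input(shape=(512,))
    out = id_mod_res_v2([feat, zid], 512)
    return Model([feat, zid], out)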


def simswap(im_size, filter_scale=1, deep=True):
    inputs = Input(shape=(im_size, im_size, 3))
    z_id = Input(shape=(512,))

    x = ReflectionPadding2D(padding=(3, 3))(inputs)
    x = Conv2D(filters=64 // filter_scale, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=64 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=256 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    if deep:
        x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)
        x = BatchNormalization()(x)
        x = Activation(tf.keras.activations.relu)(x)

    x = id_mod_res([x, z_id], 512 // filter_scale)
    x = id_mod_res([x, z_id], 512 // filter_scale)
    x = id_mod_res([x, z_id], 512 // filter_scale)
    x = id_mod_res([x, z_id], 512 // filter_scale)

    if deep:
        x = SubpixelConv2D(upsampling_factor=2)(x)
        x = Conv2D(filters=512 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x)
        x = BatchNormalization()(x)
        x = Activation(tf.keras.activations.relu)(x)

    x = SubpixelConv2D(upsampling_factor=2)(x)
    x = Conv2D(filters=256 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = SubpixelConv2D(upsampling_factor=2)(x)
    x = Conv2D(filters=128 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = SubpixelConv2D(upsampling_factor=2)(x)
    x = Conv2D(filters=64 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = ReflectionPadding2D(padding=(3, 3))(x)
    out = Conv2D(filters=3, kernel_size=7, padding='valid')(x)

    model = Model([inputs, z_id], out)
    model.summary()

    return model
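

# Usage sketch (illustration only; im_size=224 is an assumption). Builds the
# generator and runs a dummy batch; this v1 path uses AdaIN, so it runs on CPU.
def _demo_simswap():
    gen = simswap(im_size=224)
    src = tf.random.normal([1, 224, 224, 3])
    zid = tf.random.normal([1, 512])
    return gen([src, zid])  # -> (1, 224, 224, 3)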


def simswap_v2(deep=True):
    inputs = Input(shape=(224, 224, 3))
    z_id = Input(shape=(512,))

    x = ReflectionPadding2D(padding=(3, 3))(inputs)
    x = Conv2D(filters=64, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=64, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=256, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    if deep:
        x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
        x = BatchNormalization()(x)
        x = Activation(tf.keras.activations.relu)(x)

    x = id_mod_res_v2([x, z_id], 512)
    x = id_mod_res_v2([x, z_id], 512)
    x = id_mod_res_v2([x, z_id], 512)
    x = id_mod_res_v2([x, z_id], 512)
    x = id_mod_res_v2([x, z_id], 512)
    x = id_mod_res_v2([x, z_id], 512)

    if deep:
        x = UpSampling2D(interpolation='bilinear')(x)
        x = Conv2D(filters=512, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
        x = BatchNormalization()(x)
        x = Activation(tf.keras.activations.relu)(x)

    x = UpSampling2D(interpolation='bilinear')(x)
    x = Conv2D(filters=256, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = UpSampling2D(interpolation='bilinear')(x)
    x = Conv2D(filters=128, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = UpSampling2D(interpolation='bilinear')(x)
    x = Conv2D(filters=64, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = ReflectionPadding2D(padding=(3, 3))(x)
    out = Conv2D(filters=3, kernel_size=7, padding='valid')(x)
    out = Activation('sigmoid')(out)

    model = Model([inputs, z_id], out)
    model.summary()

    return model


class AdaptiveAttention(Layer):
    """Blends attribute and identity feature maps with a learned mask:
    (1 - m) * a + m * i."""

    def __init__(self, **kwargs):
        super(AdaptiveAttention, self).__init__(**kwargs)

    def call(self, inputs):
        m, a, i = inputs
        return (1 - m) * a + m * i

    def get_config(self):
        base_config = super(AdaptiveAttention, self).get_config()
        return base_config


def aad_block(inputs, c_out):
    # AAD (Adaptive Attentional Denormalization) block in the style of
    # FaceShifter: a learned mask m blends attribute-conditioned (a) and
    # identity-conditioned (i) modulations of the normalized features.
    h, z_att, z_id = inputs

    h_norm = BatchNormalization()(h)
    h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(h_norm)

    m = Activation('sigmoid')(h)

    z_att_gamma = Conv2D(filters=c_out,
                         kernel_size=1,
                         kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att)

    z_att_beta = Conv2D(filters=c_out,
                        kernel_size=1,
                        kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att)

    a = Multiply()([h_norm, z_att_gamma])
    a = Add()([a, z_att_beta])

    z_id_gamma = Dense(h_norm.shape[-1],
                       kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id)
    z_id_gamma = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_gamma)

    z_id_beta = Dense(h_norm.shape[-1],
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id)
    z_id_beta = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_beta)

    i = Multiply()([h_norm, z_id_gamma])
    i = Add()([i, z_id_beta])

    h_out = AdaptiveAttention()([m, a, i])

    return h_out


def aad_block_mod(inputs, c_out):
    # Variant of aad_block that injects identity through a modulated 1x1
    # convolution instead of a Dense gamma/beta pair.
    h, z_att, z_id = inputs

    h_norm = BatchNormalization()(h)
    h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(h_norm)

    m = Activation('sigmoid')(h)

    z_att_gamma = Conv2D(filters=c_out,
                         kernel_size=1,
                         kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att)

    z_att_beta = Conv2D(filters=c_out,
                        kernel_size=1,
                        kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att)

    a = Multiply()([h_norm, z_att_gamma])
    a = Add()([a, z_att_beta])

    z_id_gamma = Dense(h_norm.shape[-1],
                       kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_id)

    i = Conv2DMod(filters=c_out,
                  kernel_size=1,
                  padding='same',
                  kernel_initializer='he_uniform',
                  kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))([h_norm, z_id_gamma])

    h_out = AdaptiveAttention()([m, a, i])

    return h_out


def aad_res_block(inputs, c_in, c_out):
    h, z_att, z_id = inputs

    if c_in == c_out:
        aad = aad_block([h, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        aad = aad_block([conv, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        h_out = Add()([h, conv])
        return h_out
    else:
        aad = aad_block([h, z_att, z_id], c_in)
        act = ReLU()(aad)
        h_res = Conv2D(filters=c_out,
                       kernel_size=3,
                       padding='same',
                       kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        aad = aad_block([h, z_att, z_id], c_in)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        aad = aad_block([conv, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        h_out = Add()([h_res, conv])

        return h_out
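

# Usage sketch (illustration only; shapes are assumptions). Wires one AAD
# residual block into a tiny functional model.
def _demo_aad_res_block():
    h = Input(shape=(28, 28, 256))
    z_att = Input(shape=(28, 28, 256))
    z_id = Input(shape=(512,))
    out = aad_res_block([h, z_att, z_id], 256, 256)
    return Model([h, z_att, z_id], out)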


def aad_res_block_mod(inputs, c_in, c_out):
    h, z_att, z_id = inputs

    if c_in == c_out:
        aad = aad_block_mod([h, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        aad = aad_block_mod([conv, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        h_out = Add()([h, conv])
        return h_out
    else:
        aad = aad_block_mod([h, z_att, z_id], c_in)
        act = ReLU()(aad)
        h_res = Conv2D(filters=c_out,
                       kernel_size=3,
                       padding='same',
                       kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        aad = aad_block_mod([h, z_att, z_id], c_in)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        aad = aad_block_mod([conv, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        h_out = Add()([h_res, conv])

        return h_out


class FilteredReLU(Layer):
    """Alias-suppressing leaky ReLU in the spirit of StyleGAN3: upsample,
    apply the nonlinearity, low-pass filter, then downsample."""

    def __init__(self,
                 critically_sampled,
                 in_channels,
                 out_channels,
                 in_size,
                 out_size,
                 in_sampling_rate,
                 out_sampling_rate,
                 in_cutoff,
                 out_cutoff,
                 in_half_width,
                 out_half_width,
                 conv_kernel=3,
                 lrelu_upsampling=2,
                 filter_size=6,
                 conv_clamp=256,
                 use_radial_filters=False,
                 is_torgb=False,
                 **kwargs):
        super(FilteredReLU, self).__init__(**kwargs)
        self.critically_sampled = critically_sampled

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width

        self.is_torgb = is_torgb

        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.lrelu_upsampling = lrelu_upsampling
        self.conv_clamp = conv_clamp

        # Temporary (oversampled) rate at which the nonlinearity is applied.
        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)

        # Upsampling filter.
        self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate
        self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1
        self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps,
                                                   cutoff=self.in_cutoff,
                                                   width=self.in_half_width * 2,
                                                   fs=self.tmp_sampling_rate)

        # Downsampling filter.
        self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate
        self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1
        self.d_radial = use_radial_filters and not self.critically_sampled
        self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps,
                                                   cutoff=self.out_cutoff,
                                                   width=self.out_half_width * 2,
                                                   fs=self.tmp_sampling_rate,
                                                   radial=self.d_radial)

        # Total padding so the output lands exactly on out_size. As in the
        # StyleGAN3 reference, this budget assumes the preceding convolution
        # is applied with 'full' padding (input of size in_size + conv_kernel - 1).
        pad_total = (self.out_size - 1) * self.d_factor + 1
        pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor
        pad_total += self.u_taps + self.d_taps - 2
        pad_lo = (pad_total + self.u_factor) // 2
        pad_hi = pad_total - pad_lo
        self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]

        self.gain = 1 if self.is_torgb else np.sqrt(2)
        self.slope = 1 if self.is_torgb else 0.2

        self.act_funcs = {'linear':
                              {'func': lambda x, **_: x,
                               'def_alpha': 0,
                               'def_gain': 1},
                          'lrelu':
                              {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha),
                               'def_alpha': 0.2,
                               'def_gain': np.sqrt(2)},
                          }

        b_init = tf.zeros_initializer()
        self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,),
                                                     dtype="float32"),
                                trainable=True)

    def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False):
        if numtaps == 1:
            return None

        if not radial:
            # Separable 1-D Kaiser low-pass filter.
            f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return f

        # Radially symmetric jinc filter with a Kaiser window.
        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        f *= np.outer(w, w)
        f /= np.sum(f)
        return f

    def get_filter_size(self, f):
        if f is None:
            return 1, 1
        assert 1 <= f.ndim <= 2
        return f.shape[-1], f.shape[0]

    def parse_padding(self, padding):
        if isinstance(padding, int):
            padding = [padding, padding]
        assert isinstance(padding, (list, tuple))
        assert all(isinstance(x, (int, np.integer)) for x in padding)
        padding = [int(x) for x in padding]
        if len(padding) == 2:
            px, py = padding
            padding = [px, px, py, py]
        px0, px1, py0, py1 = padding
        return px0, px1, py0, py1

    def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None):
        spec = self.act_funcs[act]
        alpha = float(alpha if alpha is not None else spec['def_alpha'])
        gain = float(gain if gain is not None else spec['def_gain'])
        clamp = float(clamp if clamp is not None else -1)

        if b is not None:
            x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))])
        x = spec['func'](x, alpha=alpha)

        if gain != 1:
            x = x * gain

        if clamp >= 0:
            x = tf.clip_by_value(x, -clamp, clamp)
        return x

    def parse_scaling(self, scaling):
        if isinstance(scaling, int):
            scaling = [scaling, scaling]
        sx, sy = scaling
        assert sx >= 1 and sy >= 1
        return sx, sy

    def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
        if f is None:
            f = tf.ones([1, 1], dtype=tf.float32)

        batch_size, in_height, in_width, num_channels = x.shape

        upx, upy = self.parse_scaling(up)
        downx, downy = self.parse_scaling(down)
        padx0, padx1, pady0, pady1 = self.parse_padding(padding)

        upW = in_width * upx + padx0 + padx1
        upH = in_height * upy + pady0 + pady1
        assert upW >= f.shape[-1] and upH >= f.shape[0]

        # NHWC -> NCHW.
        x = tf.transpose(x, perm=[0, 3, 1, 2])

        # Upsample by zero-interleaving. (The original reshaped to
        # [num_channels, batch_size, ...], which scrambles samples; the batch
        # axis must come first. The interleave padding also used upx on the
        # height axis and upy on the width axis; the two are swapped here.)
        x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1])
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]])
        x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx])

        # Pad or crop (height takes pady, width takes padx).
        x = tf.pad(x, [[0, 0], [0, 0],
                       [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)],
                       [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)]])
        x = x[:, :,
              tf.math.maximum(-pady0, 0): x.shape[2] - tf.math.maximum(-pady1, 0),
              tf.math.maximum(-padx0, 0): x.shape[3] - tf.math.maximum(-padx1, 0)]

        # Scale and set up the FIR filter. NOTE: the flattening below assumes a
        # separable 1-D filter; 2-D radial filters are not supported here.
        f = f * (gain ** (f.ndim / 2))
        f = tf.cast(f, dtype=x.dtype)
        if not flip_filter:
            f = tf.reverse(f, axis=[-1])
        f = tf.reshape(f, shape=(1, 1, f.shape[-1]))
        f = tf.repeat(f, repeats=num_channels, axis=0)

        if f.shape.rank == 4:
            f_0 = tf.transpose(f, perm=[2, 3, 1, 0])
            x = tf.nn.conv2d(x, f_0, 1, 'VALID')
        else:
            # Separable filter: convolve rows, then columns.
            f_0 = tf.expand_dims(f, axis=2)
            f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0])

            f_1 = tf.expand_dims(f, axis=3)
            f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0])

            x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW')
            x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW')

        # Downsample by striding.
        x = x[:, :, ::downy, ::downx]

        # NCHW -> NHWC.
        x = tf.transpose(x, perm=[0, 2, 3, 1])
        return x

    def filtered_lrelu(self,
                       x, fu=None, fd=None, b=None,
                       up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
        px0, px1, py0, py1 = self.parse_padding(padding)

        x = self.bias_act(x=x, b=b)
        x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
        x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp)
        x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter)

        return x

    def call(self, inputs):
        return self.filtered_lrelu(inputs,
                                   fu=self.u_filter,
                                   fd=self.d_filter,
                                   b=self.bias,
                                   up=self.u_factor,
                                   down=self.d_factor,
                                   padding=self.padding,
                                   gain=self.gain,
                                   slope=self.slope,
                                   clamp=self.conv_clamp)

    def get_config(self):
        base_config = super(FilteredReLU, self).get_config()
        return base_config
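

# Usage sketch (illustration only). All channel counts, sizes, rates, cutoffs
# and half-widths below are assumptions chosen for a critically sampled 64x64
# layer; the input is 66x66 because the layer's padding math expects the
# output of a 'full'-padded 3x3 convolution. NCHW convolutions need a GPU.
def _demo_filtered_relu():
    frelu = FilteredReLU(critically_sampled=True,
                         in_channels=64, out_channels=64,
                         in_size=64, out_size=64,
                         in_sampling_rate=64, out_sampling_rate=64,
                         in_cutoff=32, out_cutoff=32,
                         in_half_width=8, out_half_width=8)
    x = tf.random.normal([1, 66, 66, 64])  # 64 + conv_kernel - 1 = 66
    return frelu(x)  # -> (1, 64, 64, 64)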


class SynthesisLayer(Layer):
    """StyleGAN3-style synthesis layer: style-modulated convolution followed
    by the filtered leaky ReLU.

    Note: the padding budget below assumes a 'full'-padded convolution (as in
    the StyleGAN3 reference), while self.conv uses 'same' padding, so output
    sizes can differ from out_size by conv_kernel - 1."""

    def __init__(self,
                 critically_sampled,
                 in_channels,
                 out_channels,
                 in_size,
                 out_size,
                 in_sampling_rate,
                 out_sampling_rate,
                 in_cutoff,
                 out_cutoff,
                 in_half_width,
                 out_half_width,
                 conv_kernel=3,
                 lrelu_upsampling=2,
                 filter_size=6,
                 conv_clamp=256,
                 use_radial_filters=False,
                 is_torgb=False,
                 **kwargs):
        super(SynthesisLayer, self).__init__(**kwargs)
        self.critically_sampled = critically_sampled

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width

        self.is_torgb = is_torgb

        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.lrelu_upsampling = lrelu_upsampling
        self.conv_clamp = conv_clamp

        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)

        # Upsampling filter.
        self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate
        self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1
        self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps,
                                                   cutoff=self.in_cutoff,
                                                   width=self.in_half_width * 2,
                                                   fs=self.tmp_sampling_rate)

        # Downsampling filter.
        self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate
        self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1
        self.d_radial = use_radial_filters and not self.critically_sampled
        self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps,
                                                   cutoff=self.out_cutoff,
                                                   width=self.out_half_width * 2,
                                                   fs=self.tmp_sampling_rate,
                                                   radial=self.d_radial)

        pad_total = (self.out_size - 1) * self.d_factor + 1
        pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor
        pad_total += self.u_taps + self.d_taps - 2
        pad_lo = (pad_total + self.u_factor) // 2
        pad_hi = pad_total - pad_lo
        self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]

        self.gain = 1 if self.is_torgb else np.sqrt(2)
        self.slope = 1 if self.is_torgb else 0.2

        self.act_funcs = {'linear':
                              {'func': lambda x, **_: x,
                               'def_alpha': 0,
                               'def_gain': 1},
                          'lrelu':
                              {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha),
                               'def_alpha': 0.2,
                               'def_gain': np.sqrt(2)},
                          }

        b_init = tf.zeros_initializer()
        self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,),
                                                     dtype="float32"),
                                trainable=True)
        self.affine = Dense(self.in_channels)
        self.conv = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same')

    def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False):
        if numtaps == 1:
            return None

        if not radial:
            f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return f

        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        f *= np.outer(w, w)
        f /= np.sum(f)
        return f

    def get_filter_size(self, f):
        if f is None:
            return 1, 1
        assert 1 <= f.ndim <= 2
        return f.shape[-1], f.shape[0]

    def parse_padding(self, padding):
        if isinstance(padding, int):
            padding = [padding, padding]
        assert isinstance(padding, (list, tuple))
        assert all(isinstance(x, (int, np.integer)) for x in padding)
        padding = [int(x) for x in padding]
        if len(padding) == 2:
            px, py = padding
            padding = [px, px, py, py]
        px0, px1, py0, py1 = padding
        return px0, px1, py0, py1

    def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None):
        spec = self.act_funcs[act]
        alpha = float(alpha if alpha is not None else spec['def_alpha'])
        gain = float(gain if gain is not None else spec['def_gain'])
        clamp = float(clamp if clamp is not None else -1)

        if b is not None:
            x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))])
        x = spec['func'](x, alpha=alpha)

        if gain != 1:
            x = x * gain

        if clamp >= 0:
            x = tf.clip_by_value(x, -clamp, clamp)
        return x

    def parse_scaling(self, scaling):
        if isinstance(scaling, int):
            scaling = [scaling, scaling]
        sx, sy = scaling
        assert sx >= 1 and sy >= 1
        return sx, sy

    def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
        if f is None:
            f = tf.ones([1, 1], dtype=tf.float32)

        batch_size, in_height, in_width, num_channels = x.shape

        upx, upy = self.parse_scaling(up)
        downx, downy = self.parse_scaling(down)
        padx0, padx1, pady0, pady1 = self.parse_padding(padding)

        upW = in_width * upx + padx0 + padx1
        upH = in_height * upy + pady0 + pady1
        assert upW >= f.shape[-1] and upH >= f.shape[0]

        # NHWC -> NCHW.
        x = tf.transpose(x, perm=[0, 3, 1, 2])

        # Upsample by zero-interleaving (batch axis first; the original put
        # num_channels first, which scrambles samples).
        x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1])
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]])
        x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx])

        # Pad or crop (height takes pady, width takes padx).
        x = tf.pad(x, [[0, 0], [0, 0],
                       [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)],
                       [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)]])
        x = x[:, :,
              tf.math.maximum(-pady0, 0): x.shape[2] - tf.math.maximum(-pady1, 0),
              tf.math.maximum(-padx0, 0): x.shape[3] - tf.math.maximum(-padx1, 0)]

        # Scale and set up the (separable) FIR filter.
        f = f * (gain ** (f.ndim / 2))
        f = tf.cast(f, dtype=x.dtype)
        if not flip_filter:
            f = tf.reverse(f, axis=[-1])
        f = tf.reshape(f, shape=(1, 1, f.shape[-1]))
        f = tf.repeat(f, repeats=num_channels, axis=0)

        if f.shape.rank == 4:
            f_0 = tf.transpose(f, perm=[2, 3, 1, 0])
            x = tf.nn.conv2d(x, f_0, 1, 'VALID')
        else:
            f_0 = tf.expand_dims(f, axis=2)
            f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0])

            f_1 = tf.expand_dims(f, axis=3)
            f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0])

            x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW')
            x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW')

        # Downsample by striding, then NCHW -> NHWC.
        x = x[:, :, ::downy, ::downx]
        x = tf.transpose(x, perm=[0, 2, 3, 1])
        return x

    def filtered_lrelu(self,
                       x, fu=None, fd=None, b=None,
                       up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
        px0, px1, py0, py1 = self.parse_padding(padding)

        x = self.bias_act(x=x, b=b)
        x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
        x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp)
        x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter)

        return x

    def call(self, inputs):
        x, w = inputs
        styles = self.affine(w)
        x = self.conv([x, styles])
        x = self.filtered_lrelu(x,
                                fu=self.u_filter,
                                fd=self.d_filter,
                                b=self.bias,
                                up=self.u_factor,
                                down=self.d_factor,
                                padding=self.padding,
                                gain=self.gain,
                                slope=self.slope,
                                clamp=self.conv_clamp)
        return x

    def get_config(self):
        base_config = super(SynthesisLayer, self).get_config()
        return base_config


class SynthesisLayerNoMod(Layer):
    """SynthesisLayer variant without style modulation: a plain convolution
    followed by the filtered leaky ReLU."""

    def __init__(self,
                 critically_sampled,
                 in_channels,
                 out_channels,
                 in_size,
                 out_size,
                 in_sampling_rate,
                 out_sampling_rate,
                 in_cutoff,
                 out_cutoff,
                 in_half_width,
                 out_half_width,
                 conv_kernel=3,
                 lrelu_upsampling=2,
                 filter_size=6,
                 conv_clamp=256,
                 use_radial_filters=False,
                 is_torgb=False,
                 batch_size=10,
                 **kwargs):
        super(SynthesisLayerNoMod, self).__init__(**kwargs)
        self.critically_sampled = critically_sampled
        self.bs = batch_size

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width

        self.is_torgb = is_torgb

        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.lrelu_upsampling = lrelu_upsampling
        self.conv_clamp = conv_clamp

        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)

        # Upsampling filter.
        self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate
        self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1
        self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps,
                                                   cutoff=self.in_cutoff,
                                                   width=self.in_half_width * 2,
                                                   fs=self.tmp_sampling_rate)

        # Downsampling filter.
        self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate
        self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1
        self.d_radial = use_radial_filters and not self.critically_sampled
        self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps,
                                                   cutoff=self.out_cutoff,
                                                   width=self.out_half_width * 2,
                                                   fs=self.tmp_sampling_rate,
                                                   radial=self.d_radial)

        # Padding budget; as in the StyleGAN3 reference this assumes a
        # 'full'-padded convolution, while self.conv below uses 'same'.
        pad_total = (self.out_size - 1) * self.d_factor + 1
        pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor
        pad_total += self.u_taps + self.d_taps - 2
        pad_lo = (pad_total + self.u_factor) // 2
        pad_hi = pad_total - pad_lo
        self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]

        self.gain = 1 if self.is_torgb else np.sqrt(2)
        self.slope = 1 if self.is_torgb else 0.2

        self.act_funcs = {'linear':
                              {'func': lambda x, **_: x,
                               'def_alpha': 0,
                               'def_gain': 1},
                          'lrelu':
                              {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha),
                               'def_alpha': 0.2,
                               'def_gain': np.sqrt(2)},
                          }

        b_init = tf.zeros_initializer()
        self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,),
                                                     dtype="float32"),
                                trainable=True)
        self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same')

    def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False):
        if numtaps == 1:
            return None

        if not radial:
            f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return f

        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        f *= np.outer(w, w)
        f /= np.sum(f)
        return f

    def get_filter_size(self, f):
        if f is None:
            return 1, 1
        assert 1 <= f.ndim <= 2
        return f.shape[-1], f.shape[0]

    def parse_padding(self, padding):
        if isinstance(padding, int):
            padding = [padding, padding]
        assert isinstance(padding, (list, tuple))
        assert all(isinstance(x, (int, np.integer)) for x in padding)
        padding = [int(x) for x in padding]
        if len(padding) == 2:
            px, py = padding
            padding = [px, px, py, py]
        px0, px1, py0, py1 = padding
        return px0, px1, py0, py1

    @tf.function
    def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None):
        spec = self.act_funcs[act]
        alpha = float(alpha if alpha is not None else spec['def_alpha'])
        gain = float(gain if gain is not None else spec['def_gain'])
        clamp = float(clamp if clamp is not None else -1)

        if b is not None:
            x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))])
        x = spec['func'](x, alpha=alpha)

        if gain != 1:
            x = x * gain

        if clamp >= 0:
            x = tf.clip_by_value(x, -clamp, clamp)
        return x

    def parse_scaling(self, scaling):
        if isinstance(scaling, int):
            scaling = [scaling, scaling]
        sx, sy = scaling
        assert sx >= 1 and sy >= 1
        return sx, sy

    @tf.function
    def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
        if f is None:
            f = tf.ones([1, 1], dtype=tf.float32)

        batch_size, in_height, in_width, num_channels = x.shape
        batch_size = tf.shape(x)[0]  # dynamic batch under tf.function

        upx, upy = self.parse_scaling(up)
        downx, downy = self.parse_scaling(down)
        padx0, padx1, pady0, pady1 = self.parse_padding(padding)

        upW = in_width * upx + padx0 + padx1
        upH = in_height * upy + pady0 + pady1
        assert upW >= f.shape[-1] and upH >= f.shape[0]

        # NHWC -> NCHW.
        x = tf.transpose(x, perm=[0, 3, 1, 2])

        # Upsample by zero-interleaving (height pads by upy, width by upx;
        # the original had the two swapped).
        x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1])
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]])
        x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx])

        # Pad or crop (height takes pady, width takes padx).
        x = tf.pad(x, [[0, 0], [0, 0],
                       [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)],
                       [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)]])
        x = x[:, :,
              tf.math.maximum(-pady0, 0): x.shape[2] - tf.math.maximum(-pady1, 0),
              tf.math.maximum(-padx0, 0): x.shape[3] - tf.math.maximum(-padx1, 0)]

        # Scale and set up the (separable) FIR filter.
        f = f * (gain ** (tf.rank(f) / 2))
        f = tf.cast(f, dtype=x.dtype)
        if not flip_filter:
            f = tf.reverse(f, axis=[-1])
        f = tf.reshape(f, shape=(1, 1, f.shape[-1]))
        f = tf.repeat(f, repeats=num_channels, axis=0)

        # Separable filter: convolve rows, then columns.
        f_0 = tf.expand_dims(f, axis=2)
        f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0])

        f_1 = tf.expand_dims(f, axis=3)
        f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0])

        x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW')
        x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW')

        # Downsample by striding, then NCHW -> NHWC.
        x = x[:, :, ::downy, ::downx]
        x = tf.transpose(x, perm=[0, 2, 3, 1])
        return x

    @tf.function
    def filtered_lrelu(self,
                       x, fu=None, fd=None, b=None,
                       up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
        px0, px1, py0, py1 = self.parse_padding(padding)

        x = self.bias_act(x=x, b=b)
        x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
        x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp)
        x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter)

        return x

    def call(self, inputs):
        x = inputs
        x = self.conv(x)
        x = self.filtered_lrelu(x,
                                fu=self.u_filter,
                                fd=self.d_filter,
                                b=self.bias,
                                up=self.u_factor,
                                down=self.d_factor,
                                padding=self.padding,
                                gain=self.gain,
                                slope=self.slope,
                                clamp=self.conv_clamp)
        return x

    def get_config(self):
        # Was super(SynthesisLayer, self), which referenced the wrong class.
        base_config = super(SynthesisLayerNoMod, self).get_config()
        return base_config


class SynthesisLayerNoModBN(Layer):
    """SynthesisLayerNoMod variant with batch normalization after the
    convolution. Its padding budget matches the 'same'-padded convolution
    used inside (no conv_kernel term)."""

    def __init__(self,
                 critically_sampled,
                 in_channels,
                 out_channels,
                 in_size,
                 out_size,
                 in_sampling_rate,
                 out_sampling_rate,
                 in_cutoff,
                 out_cutoff,
                 in_half_width,
                 out_half_width,
                 conv_kernel=3,
                 lrelu_upsampling=2,
                 filter_size=6,
                 conv_clamp=256,
                 use_radial_filters=False,
                 is_torgb=False,
                 batch_size=10,
                 **kwargs):
        super(SynthesisLayerNoModBN, self).__init__(**kwargs)
        self.critically_sampled = critically_sampled
        self.bs = batch_size

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width

        self.is_torgb = is_torgb

        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.lrelu_upsampling = lrelu_upsampling
        self.conv_clamp = conv_clamp

        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1.0 if is_torgb else lrelu_upsampling)

        # Upsampling filter.
        self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate
        self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1
        self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps,
                                                   cutoff=self.in_cutoff,
                                                   width=self.in_half_width * 2,
                                                   fs=self.tmp_sampling_rate)

        # Downsampling filter.
        self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate
        self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1
        self.d_radial = use_radial_filters and not self.critically_sampled
        self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps,
                                                   cutoff=self.out_cutoff,
                                                   width=self.out_half_width * 2,
                                                   fs=self.tmp_sampling_rate,
                                                   radial=self.d_radial)

        pad_total = (self.out_size - 1) * self.d_factor + 1
        pad_total -= self.in_size * self.u_factor
        pad_total += self.u_taps + self.d_taps - 2
        pad_lo = (pad_total + self.u_factor) // 2
        pad_hi = pad_total - pad_lo
        self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]

        self.gain = 1 if self.is_torgb else np.sqrt(2)
        self.slope = 1 if self.is_torgb else 0.2

        self.act_funcs = {'linear':
                              {'func': lambda x, **_: x,
                               'def_alpha': 0,
                               'def_gain': 1},
                          'lrelu':
                              {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha),
                               'def_alpha': 0.2,
                               'def_gain': np.sqrt(2)},
                          }

        b_init = tf.zeros_initializer()
        self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,),
                                                     dtype="float32"),
                                trainable=True)
        self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same')
        self.bn = BatchNormalization()

    def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False):
        if numtaps == 1:
            return None

        if not radial:
            f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return f

        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        f *= np.outer(w, w)
        f /= np.sum(f)
        return f

    def get_filter_size(self, f):
        if f is None:
            return 1, 1
        assert 1 <= f.ndim <= 2
        return f.shape[-1], f.shape[0]

    def parse_padding(self, padding):
        if isinstance(padding, int):
            padding = [padding, padding]
        assert isinstance(padding, (list, tuple))
        assert all(isinstance(x, (int, np.integer)) for x in padding)
        padding = [int(x) for x in padding]
        if len(padding) == 2:
            px, py = padding
            padding = [px, px, py, py]
        px0, px1, py0, py1 = padding
        return px0, px1, py0, py1

    @tf.function
    def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None):
        spec = self.act_funcs[act]
        alpha = float(alpha if alpha is not None else spec['def_alpha'])
        gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32)
        clamp = float(clamp if clamp is not None else -1)

        if b is not None:
            x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))])
        x = spec['func'](x, alpha=alpha)

        # gain is a tensor here, so it is applied unconditionally.
        x = x * gain

        if clamp >= 0:
            x = tf.clip_by_value(x, -clamp, clamp)
        return x

    def parse_scaling(self, scaling):
        if isinstance(scaling, int):
            scaling = [scaling, scaling]
        sx, sy = scaling
        assert sx >= 1 and sy >= 1
        return sx, sy

    @tf.function
    def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
        if f is None:
            f = tf.ones([1, 1], dtype=tf.float32)

        batch_size, in_height, in_width, num_channels = x.shape
        batch_size = tf.shape(x)[0]  # dynamic batch under tf.function

        upx, upy = self.parse_scaling(up)
        downx, downy = self.parse_scaling(down)
        padx0, padx1, pady0, pady1 = self.parse_padding(padding)

        upW = in_width * upx + padx0 + padx1
        upH = in_height * upy + pady0 + pady1
        assert upW >= f.shape[-1] and upH >= f.shape[0]

        # NHWC -> NCHW.
        x = tf.transpose(x, perm=[0, 3, 1, 2])

        # Upsample by zero-interleaving (height pads by upy, width by upx;
        # the original had the two swapped).
        x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1])
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]])
        x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx])

        # Pad or crop (height takes pady, width takes padx).
        x = tf.pad(x, [[0, 0], [0, 0],
                       [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)],
                       [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)]])
        x = x[:, :,
              tf.math.maximum(-pady0, 0): x.shape[2] - tf.math.maximum(-pady1, 0),
              tf.math.maximum(-padx0, 0): x.shape[3] - tf.math.maximum(-padx1, 0)]

        # Scale and set up the (separable) FIR filter.
        f = f * (gain ** (tf.rank(f) / 2))
        f = tf.cast(f, dtype=x.dtype)
        if not flip_filter:
            f = tf.reverse(f, axis=[-1])
        f = tf.reshape(f, shape=(1, 1, f.shape[-1]))
        f = tf.repeat(f, repeats=num_channels, axis=0)

        # Separable filter: convolve rows, then columns.
        f_0 = tf.expand_dims(f, axis=2)
        f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0])

        f_1 = tf.expand_dims(f, axis=3)
        f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0])

        x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW')
        x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW')

        # Downsample by striding, then NCHW -> NHWC.
        x = x[:, :, ::downy, ::downx]
        x = tf.transpose(x, perm=[0, 2, 3, 1])
        return x

    @tf.function
    def filtered_lrelu(self,
                       x, fu=None, fd=None, b=None,
                       up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
        px0, px1, py0, py1 = self.parse_padding(padding)

        x = self.bias_act(x=x, b=b)
        x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
        x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp)
        x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter)

        return x

    def call(self, inputs):
        x = inputs
        x = self.conv(x)
        x = self.bn(x)
        x = self.filtered_lrelu(x,
                                fu=self.u_filter,
                                fd=self.d_filter,
                                b=self.bias,
                                up=self.u_factor,
                                down=self.d_factor,
                                padding=self.padding,
                                gain=self.gain,
                                slope=self.slope,
                                clamp=self.conv_clamp)
        return x

    def get_config(self):
        base_config = super(SynthesisLayerNoModBN, self).get_config()
        return base_config
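

# Usage sketch (illustration only). All channel counts, sizes, rates, cutoffs
# and half-widths are assumptions for a critically sampled 64x64 layer; this
# variant's padding budget matches its internal 'same'-padded convolution, so
# the output size equals out_size. NCHW convolutions need a GPU.
def _demo_synthesis_layer_no_mod_bn():
    layer = SynthesisLayerNoModBN(critically_sampled=True,
                                  in_channels=64, out_channels=64,
                                  in_size=64, out_size=64,
                                  in_sampling_rate=64, out_sampling_rate=64,
                                  in_cutoff=32, out_cutoff=32,
                                  in_half_width=8, out_half_width=8)
    x = tf.random.normal([1, 64, 64, 64])
    return layer(x)  # -> (1, 64, 64, 64)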
class SynthesisInput(Layer): |
|
def __init__(self, |
|
w_dim, |
|
channels, |
|
size, |
|
sampling_rate, |
|
bandwidth, |
|
**kwargs): |
|
super(SynthesisInput, self).__init__(**kwargs) |
|
self.w_dim = w_dim |
|
self.channels = channels |
|
self.size = np.broadcast_to(np.asarray(size), [2]) |
|
self.sampling_rate = sampling_rate |
|
self.bandwidth = bandwidth |
|
|
|
|
|
freqs = np.random.normal(size=(int(channels[0]), 2)) |
|
radii = np.sqrt(np.sum(np.square(freqs), axis=1, keepdims=True)) |
|
freqs /= radii * np.power(np.exp(np.square(radii)), 0.25) |
|
freqs *= bandwidth |
|
phases = np.random.uniform(size=[int(channels[0])]) - 0.5 |
|
|
|
|
|
w_init = tf.random_normal_initializer() |
|
        self.weight = tf.Variable(initial_value=w_init(shape=(self.channels, self.channels),
                                                       dtype="float32"), trainable=True)
|
        # Bias-initialise the affine head to t = [1, 0, 0, 0] (identity rotation,
        # zero translation); an all-zero t would be normalised by a zero norm in call().
        self.affine = Dense(4, kernel_initializer='zeros',
                            bias_initializer=tf.constant_initializer([1.0, 0.0, 0.0, 0.0]))
|
self.transform = tf.eye(3, 3) |
|
        self.freqs = tf.constant(freqs, dtype=tf.float32)   # cast: NumPy defaults to float64
        self.phases = tf.constant(phases, dtype=tf.float32)
|
|
|
def call(self, w): |
|
|
|
transforms = tf.expand_dims(self.transform, axis=0) |
|
freqs = tf.expand_dims(self.freqs, axis=0) |
|
phases = tf.expand_dims(self.phases, axis=0) |
|
|
|
|
|
t = self.affine(w) |
|
t = t / tf.linalg.norm(t[:, :2], axis=1, keepdims=True) |
|
|
|
        # tf.Tensor does not support item assignment, so build the inverse
        # rotation (from the first two components of t) and inverse
        # translation (from the last two) with tf.stack instead.
        zeros = tf.zeros_like(t[:, 0])
        ones = tf.ones_like(t[:, 0])
        m_r = tf.stack([tf.stack([t[:, 0], -t[:, 1], zeros], axis=-1),
                        tf.stack([t[:, 1], t[:, 0], zeros], axis=-1),
                        tf.stack([zeros, zeros, ones], axis=-1)], axis=1)

        m_t = tf.stack([tf.stack([ones, zeros, -t[:, 2]], axis=-1),
                        tf.stack([zeros, ones, -t[:, 3]], axis=-1),
                        tf.stack([zeros, zeros, ones], axis=-1)], axis=1)
        transforms = m_r @ m_t @ transforms
|
|
|
|
|
        phases = phases + tf.squeeze(freqs @ transforms[:, :2, 2:], axis=2)   # [batch, channels]
|
freqs = freqs @ transforms[:, :2, :2] |
|
|
|
|
|
        # Norm over each frequency's 2-vector (axis=2), fading amplitudes
        # linearly from 1 at `bandwidth` to 0 at the output Nyquist rate.
        amplitudes = tf.clip_by_value(
            1 - (tf.linalg.norm(freqs, axis=2) - self.bandwidth)
            / (self.sampling_rate / 2 - self.bandwidth), 0, 1)
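
        # The forward pass stops short in the original file; the remainder
        # below is a sketch following the reference StyleGAN3 SynthesisInput
        # (an assumption, not recovered source): evaluate the transformed
        # Fourier features on a pixel-centred grid, then mix channels with
        # the trainable weight.
        gx = tf.linspace(-0.5 * (self.size[0] - 1) / self.sampling_rate,
                         0.5 * (self.size[0] - 1) / self.sampling_rate,
                         int(self.size[0]))
        gy = tf.linspace(-0.5 * (self.size[1] - 1) / self.sampling_rate,
                         0.5 * (self.size[1] - 1) / self.sampling_rate,
                         int(self.size[1]))
        grid_x, grid_y = tf.meshgrid(gx, gy)
        grids = tf.cast(tf.stack([grid_x, grid_y], axis=-1), tf.float32)  # [H, W, 2]

        x = tf.einsum('hwd,bcd->bhwc', grids, freqs)   # phase ramp of each sinusoid
        x = x + phases[:, tf.newaxis, tf.newaxis, :]
        x = tf.sin(x * (2 * np.pi))
        x = x * amplitudes[:, tf.newaxis, tf.newaxis, :]

        # Trainable channel mixing, normalised as in the reference code.
        x = tf.matmul(x, self.weight, transpose_b=True) / float(np.sqrt(self.channels))
        return x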
|
|
|
|
|
|
|
def get_config(self): |
|
base_config = super(SynthesisInput, self).get_config() |
|
return base_config |
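
# Usage sketch for SynthesisInput -- the numbers are illustrative assumptions,
# not the original configuration:
#   inp = SynthesisInput(w_dim=512, channels=512, size=36,
#                        sampling_rate=16, bandwidth=2)
#   x0 = inp(w)   # w: [batch, w_dim] -> x0: [batch, 36, 36, 512]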
|
|
|
|
|
class SynthesisLayerFS(Layer): |
|
|
|
def __init__(self, |
|
critically_sampled, |
|
|
|
in_channels, |
|
out_channels, |
|
in_size, |
|
out_size, |
|
in_sampling_rate, |
|
out_sampling_rate, |
|
in_cutoff, |
|
out_cutoff, |
|
in_half_width, |
|
out_half_width, |
|
|
|
conv_kernel = 3, |
|
lrelu_upsampling = 2, |
|
filter_size = 6, |
|
conv_clamp = 256, |
|
use_radial_filters = False, |
|
is_torgb = False, |
|
**kwargs): |
|
super(SynthesisLayerFS, self).__init__(**kwargs) |
|
self.critically_sampled = critically_sampled |
|
|
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.in_size = np.broadcast_to(np.asarray(in_size), [2]) |
|
self.out_size = np.broadcast_to(np.asarray(out_size), [2]) |
|
self.in_sampling_rate = in_sampling_rate |
|
self.out_sampling_rate = out_sampling_rate |
|
self.in_cutoff = in_cutoff |
|
self.out_cutoff = out_cutoff |
|
self.in_half_width = in_half_width |
|
self.out_half_width = out_half_width |
|
|
|
self.is_torgb = is_torgb |
|
|
|
self.conv_kernel = 1 if is_torgb else conv_kernel |
|
self.lrelu_upsampling = lrelu_upsampling |
|
self.conv_clamp = conv_clamp |
|
|
|
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) |
|
|
|
|
|
self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) |
|
assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate |
|
self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 |
|
self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, |
|
cutoff=self.in_cutoff, |
|
width=self.in_half_width*2, |
|
fs=self.tmp_sampling_rate) |
|
|
|
|
|
self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) |
|
assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate |
|
self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 |
|
self.d_radial = use_radial_filters and not self.critically_sampled |
|
self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, |
|
cutoff=self.out_cutoff, |
|
width=self.out_half_width*2, |
|
fs=self.tmp_sampling_rate, |
|
radial=self.d_radial) |
|
|
|
pad_total = (self.out_size - 1) * self.d_factor + 1 |
|
pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor |
|
pad_total += self.u_taps + self.d_taps - 2 |
|
pad_lo = (pad_total + self.u_factor) // 2 |
|
pad_hi = pad_total - pad_lo |
|
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] |
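        # The padding solves for the exact margins that make the chain
        # up-filter -> nonlinearity -> down-filter emit exactly out_size
        # samples: the taps the output needs, minus what the upsampled conv
        # provides, plus the combined filter support, split with a half-step
        # offset (+ u_factor) so the sampling grid stays centred.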
|
|
|
self.gain = 1 if self.is_torgb else np.sqrt(2) |
|
self.slope = 1 if self.is_torgb else 0.2 |
|
|
|
self.act_funcs = {'linear': |
|
{'func': lambda x, **_: x, |
|
'def_alpha': 0, |
|
'def_gain': 1}, |
|
'lrelu': |
|
{'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), |
|
'def_alpha': 0.2, |
|
'def_gain': np.sqrt(2)}, |
|
} |
|
|
|
b_init = tf.zeros_initializer() |
|
self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), |
|
dtype="float32"), |
|
trainable=True) |
|
self.affine = Dense(self.out_channels) |
|
self.conv_mod = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same') |
|
self.bn = BatchNormalization() |
|
self.conv_gamma = Conv2D(self.out_channels, kernel_size=1) |
|
self.conv_beta = Conv2D(self.out_channels, kernel_size=1) |
|
self.conv_gate = Conv2D(self.out_channels, kernel_size=1) |
|
self.conv_final = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') |
|
|
|
|
|
def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): |
|
if numtaps == 1: |
|
return None |
|
|
|
if not radial: |
|
f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) |
|
return f |
|
|
|
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs |
|
r = np.hypot(*np.meshgrid(x, x)) |
|
f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) |
|
beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) |
|
w = np.kaiser(numtaps, beta) |
|
f *= np.outer(w, w) |
|
f /= np.sum(f) |
|
return f |
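
    # The radial branch above windows a jinc kernel (first-order Bessel
    # function j1) with a separable Kaiser window and normalises it to unit
    # DC gain, approximating a rotation-symmetric 2-D low-pass for the
    # down-filtering path.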
|
|
|
def get_filter_size(self, f): |
|
if f is None: |
|
return 1, 1 |
|
assert 1 <= f.ndim <= 2 |
|
return f.shape[-1], f.shape[0] |
|
|
|
def parse_padding(self, padding): |
|
if isinstance(padding, int): |
|
padding = [padding, padding] |
|
assert isinstance(padding, (list, tuple)) |
|
assert all(isinstance(x, (int, np.integer)) for x in padding) |
|
padding = [int(x) for x in padding] |
|
if len(padding) == 2: |
|
px, py = padding |
|
padding = [px, px, py, py] |
|
px0, px1, py0, py1 = padding |
|
return px0, px1, py0, py1 |
|
|
|
@tf.function |
|
def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): |
|
spec = self.act_funcs[act] |
|
alpha = float(alpha if alpha is not None else spec['def_alpha']) |
|
gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32) |
|
clamp = float(clamp if clamp is not None else -1) |
|
|
|
if b is not None: |
|
x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) |
|
x = spec['func'](x, alpha=alpha) |
|
|
|
x = x * gain |
|
|
|
if clamp >= 0: |
|
x = tf.clip_by_value(x, -clamp, clamp) |
|
return x |
|
|
|
def parse_scaling(self, scaling): |
|
if isinstance(scaling, int): |
|
scaling = [scaling, scaling] |
|
sx, sy = scaling |
|
assert sx >= 1 and sy >= 1 |
|
return sx, sy |
|
|
|
@tf.function |
|
def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): |
|
if f is None: |
|
f = tf.ones([1, 1], dtype=tf.float32) |
|
|
|
batch_size, in_height, in_width, num_channels = x.shape |
|
batch_size = tf.shape(x)[0] |
|
|
|
upx, upy = self.parse_scaling(up) |
|
downx, downy = self.parse_scaling(down) |
|
padx0, padx1, pady0, pady1 = self.parse_padding(padding) |
|
|
|
upW = in_width * upx + padx0 + padx1 |
|
upH = in_height * upy + pady0 + pady1 |
|
assert upW >= f.shape[-1] and upH >= f.shape[0] |
|
|
|
|
|
x = tf.transpose(x, perm=[0, 3, 1, 2]) |
|
|
|
        x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1])
        # Zero-stuff for upsampling: upy - 1 zeros follow each row (axis 3),
        # upx - 1 zeros follow each column (axis 5).
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]])
        x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx])
|
|
|
|
|
        # NCHW: axis 2 is height (pady*), axis 3 is width (padx*). The pad
        # amounts are Python ints (see parse_padding), so built-in max keeps
        # the padded shape static.
        x = tf.pad(x, [[0, 0], [0, 0],
                       [max(pady0, 0), max(pady1, 0)],
                       [max(padx0, 0), max(padx1, 0)]])
        x = x[:, :,
              max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
              max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
|
|
|
|
        # Scale the filter so the total gain matches whether it is applied as
        # one 2-D pass or as two separable 1-D passes.
        f = f * (gain ** (f.ndim / 2))
        f = tf.cast(f, dtype=x.dtype)
        if not flip_filter:
            f = tf.reverse(f, axis=list(range(f.shape.rank)))

        if f.shape.rank == 2:
            # Non-separable (radial) filter: a single 2-D depthwise pass,
            # expressed as a grouped conv with one input channel per group.
            f_2d = tf.tile(f[:, :, tf.newaxis, tf.newaxis], [1, 1, 1, num_channels])
            x = tf.nn.conv2d(x, f_2d, 1, 'VALID', data_format='NCHW')
        else:
            # Separable 1-D filter: a horizontal pass, then a vertical pass.
            f = tf.reshape(f, shape=(1, 1, f.shape[-1]))
            f = tf.repeat(f, repeats=num_channels, axis=0)

            f_0 = tf.expand_dims(f, axis=2)                 # [C, 1, 1, taps]
            f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0])      # [1, taps, 1, C]

            f_1 = tf.expand_dims(f, axis=3)                 # [C, 1, taps, 1]
            f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0])      # [taps, 1, 1, C]

            x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW')
            x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW')
|
|
|
x = x[:, :, ::downy, ::downx] |
|
|
|
|
|
x = tf.transpose(x, perm=[0, 2, 3, 1]) |
|
return x |
|
|
|
@tf.function |
|
def filtered_lrelu(self, |
|
x, fu=None, fd=None, b=None, |
|
up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): |
|
|
|
|
|
|
|
px0, px1, py0, py1 = self.parse_padding(padding) |
|
|
|
|
|
|
|
|
|
|
|
|
|
x = self.bias_act(x=x, b=b) |
|
x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) |
|
x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) |
|
x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) |
|
|
|
return x |
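
    # aadm below gates between two conditioning paths: a style path, where
    # the affine-mapped latent w modulates the convolution (Conv2DMod), and
    # an attention path, where per-pixel scale/shift maps are predicted from
    # the attention input `a`; a learned sigmoid gate blends the two per
    # channel and location.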
|
|
|
@tf.function |
|
def aadm(self, x, w, a): |
|
w_affine = self.affine(w) |
|
x_norm = self.bn(x) |
|
|
|
x_id = self.conv_mod([x_norm, w_affine]) |
|
|
|
gate = self.conv_gate(x_norm) |
|
gate = tf.nn.sigmoid(gate) |
|
|
|
x_att_beta = self.conv_beta(a) |
|
x_att_gamma = self.conv_gamma(a) |
|
|
|
x_att = x_norm * x_att_beta + x_att_gamma |
|
|
|
h = x_id * gate + (1 - gate) * x_att |
|
|
|
return h |
|
|
|
|
|
def call(self, inputs): |
|
x, w, a = inputs |
|
x = self.conv_final(x) |
|
x = self.aadm(x, w, a) |
|
x = self.filtered_lrelu(x, |
|
fu=self.u_filter, |
|
fd=self.d_filter, |
|
b=self.bias, |
|
up=self.u_factor, |
|
down=self.d_factor, |
|
padding=self.padding, |
|
gain=self.gain, |
|
slope=self.slope, |
|
clamp=self.conv_clamp) |
|
return x |
|
|
|
def get_config(self): |
|
base_config = super(SynthesisLayerFS, self).get_config() |
|
return base_config |
|
|
|
|
|
class SynthesisLayerUpDownOnly(Layer): |
|
|
|
def __init__(self, |
|
critically_sampled, |
|
|
|
in_channels, |
|
out_channels, |
|
in_size, |
|
out_size, |
|
in_sampling_rate, |
|
out_sampling_rate, |
|
in_cutoff, |
|
out_cutoff, |
|
in_half_width, |
|
out_half_width, |
|
|
|
conv_kernel = 3, |
|
lrelu_upsampling = 2, |
|
filter_size = 6, |
|
conv_clamp = 256, |
|
use_radial_filters = False, |
|
is_torgb = False, |
|
**kwargs): |
|
super(SynthesisLayerUpDownOnly, self).__init__(**kwargs) |
|
self.critically_sampled = critically_sampled |
|
|
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.in_size = np.broadcast_to(np.asarray(in_size), [2]) |
|
self.out_size = np.broadcast_to(np.asarray(out_size), [2]) |
|
self.in_sampling_rate = in_sampling_rate |
|
self.out_sampling_rate = out_sampling_rate |
|
self.in_cutoff = in_cutoff |
|
self.out_cutoff = out_cutoff |
|
self.in_half_width = in_half_width |
|
self.out_half_width = out_half_width |
|
|
|
self.is_torgb = is_torgb |
|
|
|
self.conv_kernel = 1 if is_torgb else conv_kernel |
|
self.lrelu_upsampling = lrelu_upsampling |
|
self.conv_clamp = conv_clamp |
|
|
|
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) |
|
|
|
|
|
self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) |
|
assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate |
|
self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 |
|
self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, |
|
cutoff=self.in_cutoff, |
|
width=self.in_half_width*2, |
|
fs=self.tmp_sampling_rate) |
|
|
|
|
|
self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) |
|
assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate |
|
self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 |
|
self.d_radial = use_radial_filters and not self.critically_sampled |
|
self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, |
|
cutoff=self.out_cutoff, |
|
width=self.out_half_width*2, |
|
fs=self.tmp_sampling_rate, |
|
radial=self.d_radial) |
|
|
|
pad_total = (self.out_size - 1) * self.d_factor + 1 |
|
pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor |
|
pad_total += self.u_taps + self.d_taps - 2 |
|
pad_lo = (pad_total + self.u_factor) // 2 |
|
pad_hi = pad_total - pad_lo |
|
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] |
|
|
|
self.gain = 1 if self.is_torgb else np.sqrt(2) |
|
self.slope = 1 if self.is_torgb else 0.2 |
|
|
|
self.act_funcs = {'linear': |
|
{'func': lambda x, **_: x, |
|
'def_alpha': 0, |
|
'def_gain': 1}, |
|
'lrelu': |
|
{'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), |
|
'def_alpha': 0.2, |
|
'def_gain': np.sqrt(2)}, |
|
} |
|
|
|
def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): |
|
if numtaps == 1: |
|
return None |
|
|
|
if not radial: |
|
f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) |
|
return f |
|
|
|
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs |
|
r = np.hypot(*np.meshgrid(x, x)) |
|
f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) |
|
beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) |
|
w = np.kaiser(numtaps, beta) |
|
f *= np.outer(w, w) |
|
f /= np.sum(f) |
|
return f |
|
|
|
def get_filter_size(self, f): |
|
if f is None: |
|
return 1, 1 |
|
assert 1 <= f.ndim <= 2 |
|
return f.shape[-1], f.shape[0] |
|
|
|
def parse_padding(self, padding): |
|
if isinstance(padding, int): |
|
padding = [padding, padding] |
|
assert isinstance(padding, (list, tuple)) |
|
assert all(isinstance(x, (int, np.integer)) for x in padding) |
|
padding = [int(x) for x in padding] |
|
if len(padding) == 2: |
|
px, py = padding |
|
padding = [px, px, py, py] |
|
px0, px1, py0, py1 = padding |
|
return px0, px1, py0, py1 |
|
|
|
def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): |
|
spec = self.act_funcs[act] |
|
alpha = float(alpha if alpha is not None else spec['def_alpha']) |
|
gain = float(gain if gain is not None else spec['def_gain']) |
|
clamp = float(clamp if clamp is not None else -1) |
|
|
|
if b is not None: |
|
x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) |
|
x = spec['func'](x, alpha=alpha) |
|
|
|
if gain != 1: |
|
x = x * gain |
|
|
|
if clamp >= 0: |
|
x = tf.clip_by_value(x, -clamp, clamp) |
|
return x |
|
|
|
def parse_scaling(self, scaling): |
|
if isinstance(scaling, int): |
|
scaling = [scaling, scaling] |
|
sx, sy = scaling |
|
assert sx >= 1 and sy >= 1 |
|
return sx, sy |
|
|
|
def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): |
|
if f is None: |
|
f = tf.ones([1, 1], dtype=tf.float32) |
|
|
|
        batch_size, in_height, in_width, num_channels = x.shape
        batch_size = tf.shape(x)[0]  # the static batch size may be None
|
|
|
upx, upy = self.parse_scaling(up) |
|
downx, downy = self.parse_scaling(down) |
|
padx0, padx1, pady0, pady1 = self.parse_padding(padding) |
|
|
|
upW = in_width * upx + padx0 + padx1 |
|
upH = in_height * upy + pady0 + pady1 |
|
assert upW >= f.shape[-1] and upH >= f.shape[0] |
|
|
|
|
|
x = tf.transpose(x, perm=[0, 3, 1, 2]) |
|
|
|
|
|
        # The leading dimension must be the batch to match the NCHW layout
        # produced by the transpose above; zero-stuff rows (upy) and columns (upx).
        x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1])
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]])
        x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx])
|
|
|
|
|
        # NCHW: axis 2 is height (pady*), axis 3 is width (padx*). The pad
        # amounts are Python ints (see parse_padding), so built-in max keeps
        # the padded shape static.
        x = tf.pad(x, [[0, 0], [0, 0],
                       [max(pady0, 0), max(pady1, 0)],
                       [max(padx0, 0), max(padx1, 0)]])
        x = x[:, :,
              max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
              max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
|
|
|
|
        # Scale the filter so the total gain matches whether it is applied as
        # one 2-D pass or as two separable 1-D passes.
        f = f * (gain ** (f.ndim / 2))
        f = tf.cast(f, dtype=x.dtype)
        if not flip_filter:
            f = tf.reverse(f, axis=list(range(f.shape.rank)))

        if f.shape.rank == 2:
            # Non-separable (radial) filter: a single 2-D depthwise pass,
            # expressed as a grouped conv with one input channel per group.
            f_2d = tf.tile(f[:, :, tf.newaxis, tf.newaxis], [1, 1, 1, num_channels])
            x = tf.nn.conv2d(x, f_2d, 1, 'VALID', data_format='NCHW')
        else:
            # Separable 1-D filter: a horizontal pass, then a vertical pass.
            f = tf.reshape(f, shape=(1, 1, f.shape[-1]))
            f = tf.repeat(f, repeats=num_channels, axis=0)

            f_0 = tf.expand_dims(f, axis=2)                 # [C, 1, 1, taps]
            f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0])      # [1, taps, 1, C]

            f_1 = tf.expand_dims(f, axis=3)                 # [C, 1, taps, 1]
            f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0])      # [taps, 1, 1, C]

            x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW')
            x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW')
|
|
|
x = x[:, :, ::downy, ::downx] |
|
|
|
|
|
x = tf.transpose(x, perm=[0, 2, 3, 1]) |
|
return x |
|
|
|
|
|
def filtered_lrelu(self, |
|
x, fu=None, fd=None, b=None, |
|
up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): |
|
|
|
px0, px1, py0, py1 = self.parse_padding(padding) |
|
|
|
x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) |
|
x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) |
|
x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) |
|
|
|
return x |
|
|
|
|
|
def call(self, inputs): |
|
x = inputs |
|
x = self.filtered_lrelu(x, |
|
fu=self.u_filter, |
|
fd=self.d_filter, |
|
                                b=None,  # this layer defines no bias; filtered_lrelu ignores it here
|
up=self.u_factor, |
|
down=self.d_factor, |
|
padding=self.padding, |
|
gain=self.gain, |
|
slope=self.slope, |
|
clamp=self.conv_clamp) |
|
return x |
|
|
|
def get_config(self): |
|
base_config = super(SynthesisLayerUpDownOnly, self).get_config() |
|
return base_config |
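
# SynthesisLayerUpDownOnly keeps only the filtered resampling path of the
# layer above (no convolution, modulation, or bias). A usage sketch, with
# the signal hyperparameters as illustrative assumptions:
#   resample = SynthesisLayerUpDownOnly(critically_sampled=False,
#                                       in_channels=512, out_channels=512,
#                                       in_size=16, out_size=32,
#                                       in_sampling_rate=16, out_sampling_rate=32,
#                                       in_cutoff=8, out_cutoff=16,
#                                       in_half_width=4, out_half_width=8)
#   y = resample(x)   # x: NHWC features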
|
|
|
|
|
class Localization(Layer): |
|
def __init__(self): |
|
super(Localization, self).__init__() |
|
|
|
self.pool = MaxPooling2D() |
|
self.conv_0 = Conv2D(36, 5, activation='relu') |
|
self.conv_1 = Conv2D(36, 5, activation='relu') |
|
self.flatten = Flatten() |
|
self.fc_0 = Dense(36, activation='relu') |
|
self.fc_1 = Dense(6, bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]), |
|
kernel_initializer='zeros') |
|
self.reshape = Reshape((2, 3)) |
|
|
|
    def build(self, input_shape):
        super(Localization, self).build(input_shape)
|
|
|
    def compute_output_shape(self, input_shape):
        # call() reshapes the six affine parameters to a [batch, 2, 3] matrix.
        return [None, 2, 3]
|
|
|
def call(self, inputs): |
|
x = self.conv_0(inputs) |
|
x = self.pool(x) |
|
x = self.conv_1(x) |
|
x = self.pool(x) |
|
x = self.flatten(x) |
|
x = self.fc_0(x) |
|
theta = self.fc_1(x) |
|
theta = self.reshape(theta) |
|
|
|
return theta |
|
|
|
|
|
class BilinearInterpolation(Layer): |
|
def __init__(self, height=36, width=36): |
|
super(BilinearInterpolation, self).__init__() |
|
self.height = height |
|
self.width = width |
|
|
|
def compute_output_shape(self, input_shape): |
|
return [None, self.height, self.width, 1] |
|
|
|
def get_config(self): |
|
config = { |
|
'height': self.height, |
|
'width': self.width |
|
} |
|
base_config = super(BilinearInterpolation, self).get_config() |
|
return dict(list(base_config.items()) + list(config.items())) |
|
|
|
def advance_indexing(self, inputs, x, y): |
|
shape = tf.shape(inputs) |
|
batch_size = shape[0] |
|
|
|
batch_idx = tf.range(0, batch_size) |
|
batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1)) |
|
|
|
b = tf.tile(batch_idx, (1, self.height, self.width)) |
|
indices = tf.stack([b, y, x], 3) |
|
|
|
return tf.gather_nd(inputs, indices) |
|
|
|
def grid_generator(self, batch): |
|
        x = tf.linspace(-1.0, 1.0, self.width)    # endpoints must be floats
        y = tf.linspace(-1.0, 1.0, self.height)
|
|
|
xx, yy = tf.meshgrid(x, y) |
|
xx = tf.reshape(xx, (-1,)) |
|
yy = tf.reshape(yy, (-1,)) |
|
|
|
homogenous_coordinates = tf.stack([xx, yy, tf.ones_like(xx)]) |
|
homogenous_coordinates = tf.expand_dims(homogenous_coordinates, axis=0) |
|
homogenous_coordinates = tf.tile(homogenous_coordinates, [batch, 1, 1]) |
|
homogenous_coordinates = tf.cast(homogenous_coordinates, dtype=tf.float32) |
|
return homogenous_coordinates |
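
    # grid_generator emits homogeneous coordinates (x, y, 1) for every output
    # pixel so that a single 2x3 affine matrix theta can express rotation,
    # scale, shear and translation in one batched matmul.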
|
|
|
def interpolate(self, images, homogenous_coordinates, theta): |
|
|
|
with tf.name_scope("Transformation"): |
|
transformed = tf.matmul(theta, homogenous_coordinates) |
|
transformed = tf.transpose(transformed, perm=[0, 2, 1]) |
|
transformed = tf.reshape(transformed, [-1, self.height, self.width, 2]) |
|
|
|
x_transformed = transformed[:, :, :, 0] |
|
y_transformed = transformed[:, :, :, 1] |
|
|
|
x = ((x_transformed + 1.) * tf.cast(self.width, dtype=tf.float32)) * 0.5 |
|
y = ((y_transformed + 1.) * tf.cast(self.height, dtype=tf.float32)) * 0.5 |
|
|
|
with tf.name_scope("VaribleCasting"): |
|
x0 = tf.cast(tf.math.floor(x), dtype=tf.int32) |
|
x1 = x0 + 1 |
|
y0 = tf.cast(tf.math.floor(y), dtype=tf.int32) |
|
y1 = y0 + 1 |
|
|
|
x0 = tf.clip_by_value(x0, 0, self.width-1) |
|
x1 = tf.clip_by_value(x1, 0, self.width - 1) |
|
y0 = tf.clip_by_value(y0, 0, self.height - 1) |
|
y1 = tf.clip_by_value(y1, 0, self.height - 1) |
|
x = tf.clip_by_value(x, 0, tf.cast(self.width, dtype=tf.float32) - 1.0) |
|
y = tf.clip_by_value(y, 0, tf.cast(self.height, dtype=tf.float32) - 1.0) |
|
|
|
with tf.name_scope("AdvancedIndexing"): |
|
i_a = self.advance_indexing(images, x0, y0) |
|
i_b = self.advance_indexing(images, x0, y1) |
|
i_c = self.advance_indexing(images, x1, y0) |
|
i_d = self.advance_indexing(images, x1, y1) |
|
|
|
with tf.name_scope("Interpolation"): |
|
x0 = tf.cast(x0, dtype=tf.float32) |
|
x1 = tf.cast(x1, dtype=tf.float32) |
|
y0 = tf.cast(y0, dtype=tf.float32) |
|
y1 = tf.cast(y1, dtype=tf.float32) |
|
|
|
w_a = (x1 - x) * (y1 - y) |
|
w_b = (x1 - x) * (y - y0) |
|
w_c = (x - x0) * (y1 - y) |
|
w_d = (x - x0) * (y - y0) |
|
|
|
w_a = tf.expand_dims(w_a, axis=3) |
|
w_b = tf.expand_dims(w_b, axis=3) |
|
w_c = tf.expand_dims(w_c, axis=3) |
|
w_d = tf.expand_dims(w_d, axis=3) |
|
|
|
            # The expression is already a single sum; wrapping it in add_n was a no-op.
            return w_a * i_a + w_b * i_b + w_c * i_c + w_d * i_d
|
|
|
def call(self, inputs): |
|
images, theta = inputs |
|
homogenous_coordinates = self.grid_generator(batch=tf.shape(images)[0]) |
|
return self.interpolate(images, homogenous_coordinates, theta) |
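
# Sketch of the intended wiring (an assumption): Localization and
# BilinearInterpolation compose into a spatial transformer -- the
# localization net predicts a 2x3 affine matrix theta and the sampler
# warps the input with it.
def _spatial_transformer(images, height=36, width=36):
    theta = Localization()(images)
    return BilinearInterpolation(height=height, width=width)([images, theta])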
|
|
|
|
|
class ResBlockLR(Layer): |
|
def __init__(self, filters=16): |
|
super(ResBlockLR, self).__init__() |
|
self.filters = filters |
|
|
|
self.conv_0 = Conv2D(filters=filters, |
|
kernel_size=3, |
|
strides=1, |
|
padding='same') |
|
self.bn_0 = BatchNormalization() |
|
self.conv_1 = Conv2D(filters=filters, |
|
kernel_size=3, |
|
strides=1, |
|
padding='same') |
|
self.bn_1 = BatchNormalization() |
|
|
|
def get_config(self): |
|
config = { |
|
'filters': self.filters, |
|
} |
|
base_config = super(ResBlockLR, self).get_config() |
|
return dict(list(base_config.items()) + list(config.items())) |
|
|
|
def call(self, inputs): |
|
x = self.conv_0(inputs) |
|
x = self.bn_0(x) |
|
x = tf.nn.leaky_relu(x, alpha=0.2) |
|
x = self.conv_1(x) |
|
x = self.bn_1(x) |
|
return x + inputs |
|
|
|
|
|
class LearnedResize(Layer): |
|
def __init__(self, width, height, filters=16, in_channels=3, num_res_block=3, interpolation='bilinear'): |
|
super(LearnedResize, self).__init__() |
|
self.filters = filters |
|
self.num_res_block = num_res_block |
|
self.interpolation = interpolation |
|
self.in_channels = in_channels |
|
self.width = width |
|
self.height = height |
|
|
|
self.resize_layer = tf.keras.layers.experimental.preprocessing.Resizing(height, |
|
width, |
|
interpolation=interpolation) |
|
|
|
self.init_layers = tf.keras.models.Sequential([Conv2D(filters=filters, |
|
kernel_size=7, |
|
strides=1, |
|
padding='same'), |
|
LeakyReLU(0.2), |
|
Conv2D(filters=filters, |
|
kernel_size=1, |
|
strides=1, |
|
padding='same'), |
|
LeakyReLU(0.2), |
|
BatchNormalization() |
|
]) |
|
res_blocks = [] |
|
for i in range(num_res_block): |
|
res_blocks.append(ResBlockLR(filters=filters)) |
|
res_blocks.append(Conv2D(filters=filters, |
|
kernel_size=3, |
|
strides=1, |
|
padding='same', |
|
use_bias=False)) |
|
res_blocks.append(BatchNormalization()) |
|
self.res_block_pipe = tf.keras.models.Sequential(res_blocks) |
|
self.final_conv = Conv2D(filters=in_channels, |
|
kernel_size=3, |
|
strides=1, |
|
padding='same') |
|
|
|
|
|
    def compute_output_shape(self, input_shape):
        return [None, self.height, self.width, input_shape[-1]]
|
|
|
def get_config(self): |
|
config = { |
|
'filters': self.filters, |
|
'num_res_block': self.num_res_block, |
|
'interpolation': self.interpolation, |
|
'width': self.width, |
|
'height': self.height, |
|
} |
|
base_config = super(LearnedResize, self).get_config() |
|
return dict(list(base_config.items()) + list(config.items())) |
|
|
|
def call(self, inputs): |
|
x_l = self.init_layers(inputs) |
|
x_l_0 = self.resize_layer(x_l) |
|
x_l = self.res_block_pipe(x_l_0) |
|
x_l = x_l + x_l_0 |
|
x_l = self.final_conv(x_l) |
|
|
|
x = self.resize_layer(inputs) |
|
|
|
return x + x_l |
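
# Usage sketch (assumption): learned resizing of an RGB batch to 36x36.
#   resizer = LearnedResize(width=36, height=36)
#   small = resizer(tf.random.normal([4, 144, 144, 3]))   # -> [4, 36, 36, 3]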
|
|