import tensorflow as tf | |
import numpy as np | |
import scipy.signal as sps | |
import scipy.special as spspec | |
import tensorflow.keras.backend as K | |
import math | |
from tensorflow.python.keras.utils import conv_utils | |
from tensorflow.keras.layers import Layer, InputSpec, Conv2D, LeakyReLU, Dense, BatchNormalization, Input, Concatenate | |
from tensorflow.keras.layers import Conv2DTranspose, ReLU, Activation, UpSampling2D, Add, Reshape, Multiply | |
from tensorflow.keras.layers import AveragePooling2D, LayerNormalization, GlobalAveragePooling2D, MaxPooling2D, Flatten | |
from tensorflow.keras import initializers, constraints, regularizers | |
from tensorflow.keras.models import Model | |
from tensorflow_addons.layers import InstanceNormalization | |
def sin_activation(x, omega=30): | |
return tf.math.sin(omega * x) | |
class AdaIN(Layer): | |
def __init__(self, **kwargs): | |
super(AdaIN, self).__init__(**kwargs) | |
def build(self, input_shapes): | |
x_shape = input_shapes[0] | |
w_shape = input_shapes[1] | |
self.w_channels = w_shape[-1] | |
self.x_channels = x_shape[-1] | |
self.dense_1 = Dense(self.x_channels) | |
self.dense_2 = Dense(self.x_channels) | |
def call(self, inputs): | |
x, w = inputs | |
ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels)) | |
yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels)) | |
return ys * x + yb | |
def get_config(self): | |
config = { | |
#'w_channels': self.w_channels, | |
#'x_channels': self.x_channels | |
} | |
base_config = super(AdaIN, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
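# Illustrative usage sketch (not part of the original model code; shapes are
# assumed): AdaIN derives a per-channel scale and bias from the style/identity
# vector w and applies them to the feature map x.
def _adain_usage_example():
    x = tf.random.normal([2, 8, 8, 64])   # feature map, NHWC
    w = tf.random.normal([2, 512])        # identity / style vector
    y = AdaIN()([x, w])                   # ys * x + yb, broadcast over H and W
    return y.shape                        # TensorShape([2, 8, 8, 64])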
class Conv2DMod(Layer): | |
def __init__(self, | |
filters, | |
kernel_size, | |
strides=1, | |
padding='valid', | |
dilation_rate=1, | |
kernel_initializer='glorot_uniform', | |
kernel_regularizer=None, | |
activity_regularizer=None, | |
kernel_constraint=None, | |
demod=True, | |
**kwargs): | |
super(Conv2DMod, self).__init__(**kwargs) | |
self.filters = filters | |
self.rank = 2 | |
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size') | |
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides') | |
self.padding = conv_utils.normalize_padding(padding) | |
self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate') | |
self.kernel_initializer = initializers.get(kernel_initializer) | |
self.kernel_regularizer = regularizers.get(kernel_regularizer) | |
self.activity_regularizer = regularizers.get(activity_regularizer) | |
self.kernel_constraint = constraints.get(kernel_constraint) | |
self.demod = demod | |
self.input_spec = [InputSpec(ndim = 4), | |
InputSpec(ndim = 2)] | |
def build(self, input_shape): | |
channel_axis = -1 | |
if input_shape[0][channel_axis] is None: | |
raise ValueError('The channel dimension of the inputs ' | |
'should be defined. Found `None`.') | |
input_dim = input_shape[0][channel_axis] | |
kernel_shape = self.kernel_size + (input_dim, self.filters) | |
if input_shape[1][-1] != input_dim: | |
raise ValueError('The last dimension of modulation input should be equal to input dimension.') | |
self.kernel = self.add_weight(shape=kernel_shape, | |
initializer=self.kernel_initializer, | |
name='kernel', | |
regularizer=self.kernel_regularizer, | |
constraint=self.kernel_constraint) | |
# Set input spec. | |
self.input_spec = [InputSpec(ndim=4, axes={channel_axis: input_dim}), | |
InputSpec(ndim=2)] | |
self.built = True | |
def call(self, inputs): | |
        # To channels-first (NCHW) layout for the grouped convolution below
        x = tf.transpose(inputs[0], [0, 3, 1, 2])
#Get weight and bias modulations | |
#Make sure w's shape is compatible with self.kernel | |
w = K.expand_dims(K.expand_dims(K.expand_dims(inputs[1], axis = 1), axis = 1), axis = -1) | |
#Add minibatch layer to weights | |
wo = K.expand_dims(self.kernel, axis = 0) | |
#Modulate | |
weights = wo * (w+1) | |
#Demodulate | |
if self.demod: | |
d = K.sqrt(K.sum(K.square(weights), axis=[1,2,3], keepdims = True) + 1e-8) | |
weights = weights / d | |
#Reshape/scale input | |
x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]]) # Fused => reshape minibatch to convolution groups. | |
w = tf.reshape(tf.transpose(weights, [1, 2, 3, 0, 4]), [weights.shape[1], weights.shape[2], weights.shape[3], -1]) | |
x = tf.nn.conv2d(x, w, | |
strides=self.strides, | |
padding="SAME", | |
data_format="NCHW") | |
# Reshape/scale output. | |
x = tf.reshape(x, [-1, self.filters, x.shape[2], x.shape[3]]) # Fused => reshape convolution groups back to minibatch. | |
x = tf.transpose(x, [0, 2, 3, 1]) | |
return x | |
def compute_output_shape(self, input_shape): | |
space = input_shape[0][1:-1] | |
new_space = [] | |
for i in range(len(space)): | |
new_dim = conv_utils.conv_output_length( | |
space[i], | |
self.kernel_size[i], | |
padding=self.padding, | |
stride=self.strides[i], | |
dilation=self.dilation_rate[i]) | |
new_space.append(new_dim) | |
        return (input_shape[0][0],) + tuple(new_space) + (self.filters,)
def get_config(self): | |
config = { | |
'filters': self.filters, | |
'kernel_size': self.kernel_size, | |
'strides': self.strides, | |
'padding': self.padding, | |
'dilation_rate': self.dilation_rate, | |
'kernel_initializer': initializers.serialize(self.kernel_initializer), | |
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), | |
'activity_regularizer': | |
regularizers.serialize(self.activity_regularizer), | |
'kernel_constraint': constraints.serialize(self.kernel_constraint), | |
'demod': self.demod | |
} | |
base_config = super(Conv2DMod, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
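# Illustrative usage sketch (assumed shapes): weight-modulated convolution in the
# style of StyleGAN2, driven by a per-sample style vector whose width must match
# the input channel count. Note the layer runs a grouped NCHW convolution, so it
# will generally need a GPU build of TensorFlow.
def _conv2dmod_usage_example():
    x = tf.random.normal([2, 16, 16, 32])    # feature map, NHWC
    style = tf.random.normal([2, 32])        # one modulation scalar per input channel
    y = Conv2DMod(filters=64, kernel_size=3, padding='same')([x, style])
    return y.shape                           # TensorShape([2, 16, 16, 64])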
class CreatePatches(tf.keras.layers.Layer ): | |
def __init__( self , patch_size): | |
super( CreatePatches , self).__init__() | |
self.patch_size = patch_size | |
def call(self, inputs): | |
patches = [] | |
# For square images only ( as inputs.shape[ 1 ] = inputs.shape[ 2 ] ) | |
input_image_size = inputs.shape[ 1 ] | |
for i in range( 0 , input_image_size , self.patch_size ): | |
for j in range( 0 , input_image_size , self.patch_size ): | |
patches.append( inputs[ : , i : i + self.patch_size , j : j + self.patch_size , : ] ) | |
return patches | |
def get_config(self): | |
config = {'patch_size': self.patch_size, | |
} | |
base_config = super(CreatePatches, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
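# Illustrative usage sketch (assumed sizes): split a square image into
# non-overlapping patches; the layer returns a Python list of patch tensors.
def _create_patches_usage_example():
    x = tf.random.normal([1, 224, 224, 3])
    patches = CreatePatches(patch_size=56)(x)
    return len(patches), patches[0].shape    # (16, TensorShape([1, 56, 56, 3]))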
class SelfAttention(tf.keras.layers.Layer ): | |
def __init__( self , alpha, filters=128): | |
super(SelfAttention , self).__init__() | |
self.alpha = alpha | |
self.filters = filters | |
self.f = Conv2D(filters, 1, 1) | |
self.g = Conv2D(filters, 1, 1) | |
self.s = Conv2D(filters, 1, 1) | |
def call(self, inputs): | |
f_map = self.f(inputs) | |
f_map = tf.image.transpose(f_map) | |
g_map = self.g(inputs) | |
s_map = self.s(inputs) | |
att = f_map * g_map | |
att = att / self.alpha | |
return tf.keras.activations.softmax(att + s_map, axis=0) | |
    def get_config(self):
        config = {'alpha': self.alpha,
                  'filters': self.filters,
                  }
base_config = super(SelfAttention, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
class Sampling(Layer): | |
"""Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.""" | |
def call(self, inputs): | |
z_mean, z_log_var = inputs | |
batch = tf.shape(z_mean)[0] | |
dim = tf.shape(z_mean)[1] | |
epsilon = tf.keras.backend.random_normal(shape=(batch, dim)) | |
return z_mean + tf.exp(0.5 * z_log_var) * epsilon | |
def get_config(self): | |
config = { | |
} | |
base_config = super(Sampling, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
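# Illustrative usage sketch (assumed shapes): the reparameterisation trick used by
# a VAE bottleneck, z = mu + exp(0.5 * log_var) * eps with eps ~ N(0, I).
def _sampling_usage_example():
    z_mean = tf.zeros([4, 128])
    z_log_var = tf.zeros([4, 128])
    z = Sampling()([z_mean, z_log_var])      # standard-normal samples in this case
    return z.shape                           # TensorShape([4, 128])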
class ArcMarginPenaltyLogists(tf.keras.layers.Layer): | |
"""ArcMarginPenaltyLogists""" | |
def __init__(self, num_classes, margin=0.7, logist_scale=64, **kwargs): | |
super(ArcMarginPenaltyLogists, self).__init__(**kwargs) | |
self.num_classes = num_classes | |
self.margin = margin | |
self.logist_scale = logist_scale | |
def build(self, input_shape): | |
        self.w = self.add_weight(
            "weights", shape=[int(input_shape[-1]), self.num_classes])
self.cos_m = tf.identity(math.cos(self.margin), name='cos_m') | |
self.sin_m = tf.identity(math.sin(self.margin), name='sin_m') | |
self.th = tf.identity(math.cos(math.pi - self.margin), name='th') | |
self.mm = tf.multiply(self.sin_m, self.margin, name='mm') | |
def call(self, embds, labels): | |
normed_embds = tf.nn.l2_normalize(embds, axis=1, name='normed_embd') | |
normed_w = tf.nn.l2_normalize(self.w, axis=0, name='normed_weights') | |
cos_t = tf.matmul(normed_embds, normed_w, name='cos_t') | |
sin_t = tf.sqrt(1. - cos_t ** 2, name='sin_t') | |
cos_mt = tf.subtract( | |
cos_t * self.cos_m, sin_t * self.sin_m, name='cos_mt') | |
cos_mt = tf.where(cos_t > self.th, cos_mt, cos_t - self.mm) | |
mask = tf.one_hot(tf.cast(labels, tf.int32), depth=self.num_classes, | |
name='one_hot_mask') | |
logists = tf.where(mask == 1., cos_mt, cos_t) | |
logists = tf.multiply(logists, self.logist_scale, 'arcface_logist') | |
return logists | |
def get_config(self): | |
config = {'num_classes': self.num_classes, | |
'margin': self.margin, | |
'logist_scale': self.logist_scale | |
} | |
base_config = super(ArcMarginPenaltyLogists, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
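# Illustrative usage sketch (assumed sizes): ArcFace-style margin logits computed
# from embeddings and integer class labels; the result would normally be fed into
# a softmax cross-entropy loss.
def _arcmargin_usage_example():
    embds = tf.random.normal([8, 512])
    labels = tf.constant([0, 1, 2, 3, 4, 5, 6, 7])
    logits = ArcMarginPenaltyLogists(num_classes=10)(embds, labels)
    return logits.shape                      # TensorShape([8, 10])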
class KLLossLayer(tf.keras.layers.Layer): | |
"""ArcMarginPenaltyLogists""" | |
def __init__(self, beta=1.5, **kwargs): | |
super(KLLossLayer, self).__init__(**kwargs) | |
self.beta = beta | |
def call(self, inputs): | |
z_mean, z_log_var = inputs | |
        kl_loss = tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = -0.5 * kl_loss * self.beta
        # Note: the loss term is multiplied by 0, so the KL divergence is only
        # tracked as a metric here and does not contribute to the training loss.
        self.add_loss(kl_loss * 0)
        self.add_metric(kl_loss, 'kl_loss')
return inputs | |
def get_config(self): | |
config = { | |
'beta': self.beta | |
} | |
base_config = super(KLLossLayer, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
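# Illustrative usage sketch (assumed shapes): the layer passes (z_mean, z_log_var)
# through unchanged while reporting the KL divergence as the 'kl_loss' metric.
def _kl_loss_usage_example():
    z_mean = tf.zeros([4, 128])
    z_log_var = tf.zeros([4, 128])
    outputs = KLLossLayer(beta=1.5)([z_mean, z_log_var])
    return outputs                           # [z_mean, z_log_var], unchanged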
class ReflectionPadding2D(Layer): | |
def __init__(self, padding=(1, 1), **kwargs): | |
self.padding = tuple(padding) | |
self.input_spec = [InputSpec(ndim=4)] | |
super(ReflectionPadding2D, self).__init__(**kwargs) | |
def compute_output_shape(self, s): | |
return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3]) | |
def call(self, x, mask=None): | |
w_pad,h_pad = self.padding | |
return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT') | |
def get_config(self): | |
config = { | |
'padding': self.padding, | |
} | |
base_config = super(ReflectionPadding2D, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
class ResBlock(Layer): | |
def __init__(self, fil, **kwargs): | |
super(ResBlock, self).__init__(**kwargs) | |
self.fil = fil | |
self.conv_0 = Conv2D(kernel_size=3, filters=fil, strides=1) | |
self.conv_1 = Conv2D(kernel_size=3, filters=fil, strides=1) | |
self.res = Conv2D(kernel_size=1, filters=1, strides=1) | |
self.lrelu = LeakyReLU(0.2) | |
self.padding = ReflectionPadding2D(padding=(1, 1)) | |
def call(self, inputs): | |
res = self.res(inputs) | |
x = self.padding(inputs) | |
x = self.conv_0(x) | |
x = self.lrelu(x) | |
x = self.padding(x) | |
x = self.conv_1(x) | |
x = self.lrelu(x) | |
out = x + res | |
return out | |
def get_config(self): | |
config = { | |
'fil': self.fil | |
} | |
base_config = super(ResBlock, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
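# Illustrative usage sketch (assumed shapes): reflection-padded residual block;
# the valid 3x3 convolutions plus (1, 1) reflection padding keep the spatial size.
def _res_block_usage_example():
    x = tf.random.normal([1, 56, 56, 64])
    y = ResBlock(fil=64)(x)
    return y.shape                           # TensorShape([1, 56, 56, 64])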
class SubpixelConv2D(Layer): | |
""" Subpixel Conv2D Layer | |
upsampling a layer from (h, w, c) to (h*r, w*r, c/(r*r)), | |
where r is the scaling factor, default to 4 | |
# Arguments | |
upsampling_factor: the scaling factor | |
# Input shape | |
Arbitrary. Use the keyword argument `input_shape` | |
(tuple of integers, does not include the samples axis) | |
when using this layer as the first layer in a model. | |
# Output shape | |
the second and the third dimension increased by a factor of | |
`upsampling_factor`; the last layer decreased by a factor of | |
`upsampling_factor^2`. | |
# References | |
Real-Time Single Image and Video Super-Resolution Using an Efficient | |
Sub-Pixel Convolutional Neural Network Shi et Al. https://arxiv.org/abs/1609.05158 | |
""" | |
def __init__(self, upsampling_factor=4, **kwargs): | |
super(SubpixelConv2D, self).__init__(**kwargs) | |
self.upsampling_factor = upsampling_factor | |
def build(self, input_shape): | |
last_dim = input_shape[-1] | |
factor = self.upsampling_factor * self.upsampling_factor | |
        if last_dim % factor != 0:
            raise ValueError('Channel count ' + str(last_dim) + ' must be an '
                             'integer multiple of upsampling_factor^2 (' +
                             str(factor) + ').')
def call(self, inputs, **kwargs): | |
return tf.nn.depth_to_space( inputs, self.upsampling_factor ) | |
def get_config(self): | |
config = { 'upsampling_factor': self.upsampling_factor, } | |
base_config = super(SubpixelConv2D, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def compute_output_shape(self, input_shape): | |
factor = self.upsampling_factor * self.upsampling_factor | |
input_shape_1 = None | |
if input_shape[1] is not None: | |
input_shape_1 = input_shape[1] * self.upsampling_factor | |
input_shape_2 = None | |
if input_shape[2] is not None: | |
input_shape_2 = input_shape[2] * self.upsampling_factor | |
dims = [ input_shape[0], | |
input_shape_1, | |
input_shape_2, | |
int(input_shape[3]/factor) | |
] | |
return tuple( dims ) | |
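# Illustrative usage sketch (assumed shapes): pixel-shuffle upsampling via
# depth_to_space, (h, w, c) -> (2h, 2w, c / 4) for an upsampling_factor of 2.
def _subpixel_usage_example():
    x = tf.random.normal([1, 8, 8, 64])
    y = SubpixelConv2D(upsampling_factor=2)(x)
    return y.shape                           # TensorShape([1, 16, 16, 16])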
def id_mod_res(inputs, c): | |
feature_map, z_id = inputs | |
x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(feature_map) | |
x = AdaIN()([x, z_id]) | |
x = ReLU()(x) | |
x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) | |
x = AdaIN()([x, z_id]) | |
out = Add()([x, feature_map]) | |
return out | |
def id_mod_res_v2(inputs, c): | |
feature_map, z_id = inputs | |
affine = Dense(feature_map.shape[-1])(z_id) | |
x = Conv2DMod(c, kernel_size=3, padding='same', | |
kernel_regularizer=tf.keras.regularizers.l2(0.0001))([feature_map, affine]) | |
x = ReLU()(x) | |
affine = Dense(x.shape[-1])(z_id) | |
x = Conv2DMod(c, kernel_size=3, padding='same', | |
kernel_regularizer=tf.keras.regularizers.l2(0.0001))([x, affine]) | |
    out = Add()([x, feature_map])
    return out
def simswap(im_size, filter_scale=1, deep=True): | |
inputs = Input(shape=(im_size, im_size, 3)) | |
z_id = Input(shape=(512,)) | |
x = ReflectionPadding2D(padding=(3, 3))(inputs) | |
x = Conv2D(filters=64 // filter_scale, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 112 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = Conv2D(filters=64 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 56 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = Conv2D(filters=256 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 28 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 14 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) # 14 | |
if deep: | |
x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 7 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = id_mod_res([x, z_id], 512 // filter_scale) | |
x = id_mod_res([x, z_id], 512 // filter_scale) | |
x = id_mod_res([x, z_id], 512 // filter_scale) | |
x = id_mod_res([x, z_id], 512 // filter_scale) | |
if deep: | |
x = SubpixelConv2D(upsampling_factor=2)(x) | |
x = Conv2D(filters=512 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = SubpixelConv2D(upsampling_factor=2)(x) | |
x = Conv2D(filters=256 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = SubpixelConv2D(upsampling_factor=2)(x) | |
x = Conv2D(filters=128 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) # 56 | |
x = SubpixelConv2D(upsampling_factor=2)(x) | |
x = Conv2D(filters=64 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) # 112 | |
x = ReflectionPadding2D(padding=(3, 3))(x) | |
out = Conv2D(filters=3, kernel_size=7, padding='valid')(x) # 112 | |
model = Model([inputs, z_id], out) | |
model.summary() | |
return model | |
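# Illustrative usage sketch (assumed sizes): build the generator for 224x224 inputs
# and run a forward pass with a random 512-d identity embedding.
def _simswap_usage_example():
    gen = simswap(im_size=224, filter_scale=1, deep=True)
    face = tf.random.normal([1, 224, 224, 3])
    z_id = tf.random.normal([1, 512])
    swapped = gen([face, z_id])
    return swapped.shape                     # TensorShape([1, 224, 224, 3])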
def simswap_v2(deep=True): | |
inputs = Input(shape=(224, 224, 3)) | |
z_id = Input(shape=(512,)) | |
x = ReflectionPadding2D(padding=(3, 3))(inputs) | |
x = Conv2D(filters=64, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 112 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = Conv2D(filters=64, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 56 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = Conv2D(filters=256, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 28 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 14 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) # 14 | |
if deep: | |
x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 7 | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = id_mod_res_v2([x, z_id], 512) | |
x = id_mod_res_v2([x, z_id], 512) | |
x = id_mod_res_v2([x, z_id], 512) | |
x = id_mod_res_v2([x, z_id], 512) | |
x = id_mod_res_v2([x, z_id], 512) | |
x = id_mod_res_v2([x, z_id], 512) | |
if deep: | |
x = UpSampling2D(interpolation='bilinear')(x) | |
x = Conv2D(filters=512, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = UpSampling2D(interpolation='bilinear')(x) | |
x = Conv2D(filters=256, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) | |
x = UpSampling2D(interpolation='bilinear')(x) | |
x = Conv2D(filters=128, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) # 56 | |
x = UpSampling2D(interpolation='bilinear')(x) | |
x = Conv2D(filters=64, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) | |
x = BatchNormalization()(x) | |
x = Activation(tf.keras.activations.relu)(x) # 112 | |
x = ReflectionPadding2D(padding=(3, 3))(x) | |
out = Conv2D(filters=3, kernel_size=7, padding='valid')(x) # 112 | |
out = Activation('sigmoid')(out) | |
model = Model([inputs, z_id], out) | |
model.summary() | |
return model | |
class AdaptiveAttention(Layer): | |
def __init__(self, **kwargs): | |
super(AdaptiveAttention, self).__init__(**kwargs) | |
def call(self, inputs): | |
m, a, i = inputs | |
return (1 - m) * a + m * i | |
def get_config(self): | |
base_config = super(AdaptiveAttention, self).get_config() | |
return base_config | |
def aad_block(inputs, c_out): | |
h, z_att, z_id = inputs | |
h_norm = BatchNormalization()(h) | |
h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(h_norm) | |
m = Activation('sigmoid')(h) | |
z_att_gamma = Conv2D(filters=c_out, | |
kernel_size=1, | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att) | |
z_att_beta = Conv2D(filters=c_out, | |
kernel_size=1, | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att) | |
a = Multiply()([h_norm, z_att_gamma]) | |
a = Add()([a, z_att_beta]) | |
z_id_gamma = Dense(h_norm.shape[-1], | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id) | |
z_id_gamma = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_gamma) | |
z_id_beta = Dense(h_norm.shape[-1], | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id) | |
z_id_beta = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_beta) | |
i = Multiply()([h_norm, z_id_gamma]) | |
i = Add()([i, z_id_beta]) | |
h_out = AdaptiveAttention()([m, a, i]) | |
return h_out | |
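# Illustrative usage sketch (assumed shapes): one AAD block blending an attribute
# feature map z_att with an identity vector z_id under a learned attention mask,
# wired up with functional-API Inputs.
def _aad_block_usage_example():
    h = Input(shape=(28, 28, 256))
    z_att = Input(shape=(28, 28, 256))
    z_id = Input(shape=(512,))
    h_out = aad_block([h, z_att, z_id], c_out=256)
    return Model([h, z_att, z_id], h_out)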
def aad_block_mod(inputs, c_out): | |
h, z_att, z_id = inputs | |
h_norm = BatchNormalization()(h) | |
h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(h_norm) | |
m = Activation('sigmoid')(h) | |
z_att_gamma = Conv2D(filters=c_out, | |
kernel_size=1, | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att) | |
z_att_beta = Conv2D(filters=c_out, | |
kernel_size=1, | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att) | |
a = Multiply()([h_norm, z_att_gamma]) | |
a = Add()([a, z_att_beta]) | |
z_id_gamma = Dense(h_norm.shape[-1], | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_id) | |
i = Conv2DMod(filters=c_out, | |
kernel_size=1, | |
padding='same', | |
kernel_initializer='he_uniform', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))([h_norm, z_id_gamma]) | |
h_out = AdaptiveAttention()([m, a, i]) | |
return h_out | |
def aad_res_block(inputs, c_in, c_out): | |
h, z_att, z_id = inputs | |
if c_in == c_out: | |
aad = aad_block([h, z_att, z_id], c_out) | |
act = ReLU()(aad) | |
conv = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) | |
aad = aad_block([conv, z_att, z_id], c_out) | |
act = ReLU()(aad) | |
conv = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) | |
h_out = Add()([h, conv]) | |
return h_out | |
else: | |
aad = aad_block([h, z_att, z_id], c_in) | |
act = ReLU()(aad) | |
h_res = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) | |
aad = aad_block([h, z_att, z_id], c_in) | |
act = ReLU()(aad) | |
conv = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) | |
aad = aad_block([conv, z_att, z_id], c_out) | |
act = ReLU()(aad) | |
conv = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) | |
h_out = Add()([h_res, conv]) | |
return h_out | |
def aad_res_block_mod(inputs, c_in, c_out): | |
h, z_att, z_id = inputs | |
if c_in == c_out: | |
aad = aad_block_mod([h, z_att, z_id], c_out) | |
act = ReLU()(aad) | |
conv = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) | |
aad = aad_block_mod([conv, z_att, z_id], c_out) | |
act = ReLU()(aad) | |
conv = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) | |
h_out = Add()([h, conv]) | |
return h_out | |
else: | |
aad = aad_block_mod([h, z_att, z_id], c_in) | |
act = ReLU()(aad) | |
h_res = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) | |
aad = aad_block_mod([h, z_att, z_id], c_in) | |
act = ReLU()(aad) | |
conv = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) | |
aad = aad_block_mod([conv, z_att, z_id], c_out) | |
act = ReLU()(aad) | |
conv = Conv2D(filters=c_out, | |
kernel_size=3, | |
padding='same', | |
kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) | |
h_out = Add()([h_res, conv]) | |
return h_out | |
class FilteredReLU(Layer): | |
def __init__(self, | |
critically_sampled, | |
in_channels, | |
out_channels, | |
in_size, | |
out_size, | |
in_sampling_rate, | |
out_sampling_rate, | |
in_cutoff, | |
out_cutoff, | |
in_half_width, | |
out_half_width, | |
conv_kernel = 3, | |
lrelu_upsampling = 2, | |
filter_size = 6, | |
conv_clamp = 256, | |
use_radial_filters = False, | |
is_torgb = False, | |
**kwargs): | |
super(FilteredReLU, self).__init__(**kwargs) | |
self.critically_sampled = critically_sampled | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.in_size = np.broadcast_to(np.asarray(in_size), [2]) | |
self.out_size = np.broadcast_to(np.asarray(out_size), [2]) | |
self.in_sampling_rate = in_sampling_rate | |
self.out_sampling_rate = out_sampling_rate | |
self.in_cutoff = in_cutoff | |
self.out_cutoff = out_cutoff | |
self.in_half_width = in_half_width | |
self.out_half_width = out_half_width | |
self.is_torgb = is_torgb | |
self.conv_kernel = 1 if is_torgb else conv_kernel | |
self.lrelu_upsampling = lrelu_upsampling | |
self.conv_clamp = conv_clamp | |
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) | |
# Up sampling filter | |
self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) | |
assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate | |
self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 | |
self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, | |
cutoff=self.in_cutoff, | |
width=self.in_half_width*2, | |
fs=self.tmp_sampling_rate) | |
# Down sampling filter | |
self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) | |
assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate | |
self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 | |
self.d_radial = use_radial_filters and not self.critically_sampled | |
self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, | |
cutoff=self.out_cutoff, | |
width=self.out_half_width*2, | |
fs=self.tmp_sampling_rate, | |
radial=self.d_radial) | |
# Compute padding | |
pad_total = (self.out_size - 1) * self.d_factor + 1 | |
pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor | |
pad_total += self.u_taps + self.d_taps - 2 | |
pad_lo = (pad_total + self.u_factor) // 2 | |
pad_hi = pad_total - pad_lo | |
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] | |
self.gain = 1 if self.is_torgb else np.sqrt(2) | |
self.slope = 1 if self.is_torgb else 0.2 | |
self.act_funcs = {'linear': | |
{'func': lambda x, **_: x, | |
'def_alpha': 0, | |
'def_gain': 1}, | |
'lrelu': | |
{'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), | |
'def_alpha': 0.2, | |
'def_gain': np.sqrt(2)}, | |
} | |
b_init = tf.zeros_initializer() | |
self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), | |
dtype="float32"), | |
trainable=True) | |
def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): | |
if numtaps == 1: | |
return None | |
if not radial: | |
f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) | |
return f | |
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs | |
r = np.hypot(*np.meshgrid(x, x)) | |
f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) | |
beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) | |
w = np.kaiser(numtaps, beta) | |
f *= np.outer(w, w) | |
f /= np.sum(f) | |
return f | |
def get_filter_size(self, f): | |
if f is None: | |
return 1, 1 | |
assert 1 <= f.ndim <= 2 | |
return f.shape[-1], f.shape[0] # width, height | |
def parse_padding(self, padding): | |
if isinstance(padding, int): | |
padding = [padding, padding] | |
assert isinstance(padding, (list, tuple)) | |
assert all(isinstance(x, (int, np.integer)) for x in padding) | |
padding = [int(x) for x in padding] | |
if len(padding) == 2: | |
px, py = padding | |
padding = [px, px, py, py] | |
px0, px1, py0, py1 = padding | |
return px0, px1, py0, py1 | |
def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): | |
spec = self.act_funcs[act] | |
alpha = float(alpha if alpha is not None else spec['def_alpha']) | |
gain = float(gain if gain is not None else spec['def_gain']) | |
clamp = float(clamp if clamp is not None else -1) | |
if b is not None: | |
x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) | |
x = spec['func'](x, alpha=alpha) | |
if gain != 1: | |
x = x * gain | |
if clamp >= 0: | |
x = tf.clip_by_value(x, -clamp, clamp) | |
return x | |
def parse_scaling(self, scaling): | |
if isinstance(scaling, int): | |
scaling = [scaling, scaling] | |
sx, sy = scaling | |
assert sx >= 1 and sy >= 1 | |
return sx, sy | |
def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): | |
if f is None: | |
f = tf.ones([1, 1], dtype=tf.float32) | |
batch_size, in_height, in_width, num_channels = x.shape | |
upx, upy = self.parse_scaling(up) | |
downx, downy = self.parse_scaling(down) | |
padx0, padx1, pady0, pady1 = self.parse_padding(padding) | |
upW = in_width * upx + padx0 + padx1 | |
upH = in_height * upy + pady0 + pady1 | |
assert upW >= f.shape[-1] and upH >= f.shape[0] | |
# Channel first format. | |
x = tf.transpose(x, perm=[0, 3, 1, 2]) | |
# Upsample by inserting zeros. | |
x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1]) | |
x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) | |
x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) | |
# Pad or crop. | |
x = tf.pad(x, [[0, 0], [0, 0], | |
[tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], | |
[tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) | |
x = x[:, :, | |
tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), | |
tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] | |
# Setup filter. | |
f = f * (gain ** (f.ndim / 2)) | |
f = tf.cast(f, dtype=x.dtype) | |
if not flip_filter: | |
f = tf.reverse(f, axis=[-1]) | |
f = tf.reshape(f, shape=(1, 1, f.shape[-1])) | |
f = tf.repeat(f, repeats=num_channels, axis=0) | |
if tf.rank(f) == 4: | |
f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) | |
x = tf.nn.conv2d(x, f_0, 1, 'VALID') | |
else: | |
f_0 = tf.expand_dims(f, axis=2) | |
f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) | |
f_1 = tf.expand_dims(f, axis=3) | |
f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) | |
x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') | |
x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') | |
x = x[:, :, ::downy, ::downx] | |
# Back to channel last. | |
x = tf.transpose(x, perm=[0, 2, 3, 1]) | |
return x | |
def filtered_lrelu(self, | |
x, fu=None, fd=None, b=None, | |
up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): | |
#fu_w, fu_h = self.get_filter_size(fu) | |
#fd_w, fd_h = self.get_filter_size(fd) | |
px0, px1, py0, py1 = self.parse_padding(padding) | |
#batch_size, in_h, in_w, channels = x.shape | |
#in_dtype = x.dtype | |
#out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down | |
#out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down | |
x = self.bias_act(x=x, b=b) | |
x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) | |
x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) | |
x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) | |
return x | |
def call(self, inputs): | |
return self.filtered_lrelu(inputs, | |
fu=self.u_filter, | |
fd=self.d_filter, | |
b=self.bias, | |
up=self.u_factor, | |
down=self.d_factor, | |
padding=self.padding, | |
gain=self.gain, | |
slope=self.slope, | |
clamp=self.conv_clamp) | |
def get_config(self): | |
base_config = super(FilteredReLU, self).get_config() | |
return base_config | |
class SynthesisLayer(Layer): | |
def __init__(self, | |
critically_sampled, | |
in_channels, | |
out_channels, | |
in_size, | |
out_size, | |
in_sampling_rate, | |
out_sampling_rate, | |
in_cutoff, | |
out_cutoff, | |
in_half_width, | |
out_half_width, | |
conv_kernel = 3, | |
lrelu_upsampling = 2, | |
filter_size = 6, | |
conv_clamp = 256, | |
use_radial_filters = False, | |
is_torgb = False, | |
**kwargs): | |
super(SynthesisLayer, self).__init__(**kwargs) | |
self.critically_sampled = critically_sampled | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.in_size = np.broadcast_to(np.asarray(in_size), [2]) | |
self.out_size = np.broadcast_to(np.asarray(out_size), [2]) | |
self.in_sampling_rate = in_sampling_rate | |
self.out_sampling_rate = out_sampling_rate | |
self.in_cutoff = in_cutoff | |
self.out_cutoff = out_cutoff | |
self.in_half_width = in_half_width | |
self.out_half_width = out_half_width | |
self.is_torgb = is_torgb | |
self.conv_kernel = 1 if is_torgb else conv_kernel | |
self.lrelu_upsampling = lrelu_upsampling | |
self.conv_clamp = conv_clamp | |
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) | |
# Up sampling filter | |
self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) | |
assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate | |
self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 | |
self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, | |
cutoff=self.in_cutoff, | |
width=self.in_half_width*2, | |
fs=self.tmp_sampling_rate) | |
# Down sampling filter | |
self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) | |
assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate | |
self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 | |
self.d_radial = use_radial_filters and not self.critically_sampled | |
self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, | |
cutoff=self.out_cutoff, | |
width=self.out_half_width*2, | |
fs=self.tmp_sampling_rate, | |
radial=self.d_radial) | |
# Compute padding | |
pad_total = (self.out_size - 1) * self.d_factor + 1 | |
pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor | |
pad_total += self.u_taps + self.d_taps - 2 | |
pad_lo = (pad_total + self.u_factor) // 2 | |
pad_hi = pad_total - pad_lo | |
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] | |
self.gain = 1 if self.is_torgb else np.sqrt(2) | |
self.slope = 1 if self.is_torgb else 0.2 | |
self.act_funcs = {'linear': | |
{'func': lambda x, **_: x, | |
'def_alpha': 0, | |
'def_gain': 1}, | |
'lrelu': | |
{'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), | |
'def_alpha': 0.2, | |
'def_gain': np.sqrt(2)}, | |
} | |
b_init = tf.zeros_initializer() | |
self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), | |
dtype="float32"), | |
trainable=True) | |
self.affine = Dense(self.in_channels) | |
self.conv = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same') | |
def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): | |
if numtaps == 1: | |
return None | |
if not radial: | |
f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) | |
return f | |
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs | |
r = np.hypot(*np.meshgrid(x, x)) | |
f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) | |
beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) | |
w = np.kaiser(numtaps, beta) | |
f *= np.outer(w, w) | |
f /= np.sum(f) | |
return f | |
def get_filter_size(self, f): | |
if f is None: | |
return 1, 1 | |
assert 1 <= f.ndim <= 2 | |
return f.shape[-1], f.shape[0] # width, height | |
def parse_padding(self, padding): | |
if isinstance(padding, int): | |
padding = [padding, padding] | |
assert isinstance(padding, (list, tuple)) | |
assert all(isinstance(x, (int, np.integer)) for x in padding) | |
padding = [int(x) for x in padding] | |
if len(padding) == 2: | |
px, py = padding | |
padding = [px, px, py, py] | |
px0, px1, py0, py1 = padding | |
return px0, px1, py0, py1 | |
def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): | |
spec = self.act_funcs[act] | |
alpha = float(alpha if alpha is not None else spec['def_alpha']) | |
gain = float(gain if gain is not None else spec['def_gain']) | |
clamp = float(clamp if clamp is not None else -1) | |
if b is not None: | |
x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) | |
x = spec['func'](x, alpha=alpha) | |
if gain != 1: | |
x = x * gain | |
if clamp >= 0: | |
x = tf.clip_by_value(x, -clamp, clamp) | |
return x | |
def parse_scaling(self, scaling): | |
if isinstance(scaling, int): | |
scaling = [scaling, scaling] | |
sx, sy = scaling | |
assert sx >= 1 and sy >= 1 | |
return sx, sy | |
def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): | |
if f is None: | |
f = tf.ones([1, 1], dtype=tf.float32) | |
batch_size, in_height, in_width, num_channels = x.shape | |
upx, upy = self.parse_scaling(up) | |
downx, downy = self.parse_scaling(down) | |
padx0, padx1, pady0, pady1 = self.parse_padding(padding) | |
upW = in_width * upx + padx0 + padx1 | |
upH = in_height * upy + pady0 + pady1 | |
assert upW >= f.shape[-1] and upH >= f.shape[0] | |
# Channel first format. | |
x = tf.transpose(x, perm=[0, 3, 1, 2]) | |
# Upsample by inserting zeros. | |
x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1]) | |
x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) | |
x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) | |
# Pad or crop. | |
x = tf.pad(x, [[0, 0], [0, 0], | |
[tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], | |
[tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) | |
x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] | |
# Setup filter. | |
f = f * (gain ** (f.ndim / 2)) | |
f = tf.cast(f, dtype=x.dtype) | |
if not flip_filter: | |
f = tf.reverse(f, axis=[-1]) | |
f = tf.reshape(f, shape=(1, 1, f.shape[-1])) | |
f = tf.repeat(f, repeats=num_channels, axis=0) | |
if tf.rank(f) == 4: | |
f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) | |
x = tf.nn.conv2d(x, f_0, 1, 'VALID') | |
else: | |
f_0 = tf.expand_dims(f, axis=2) | |
f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) | |
f_1 = tf.expand_dims(f, axis=3) | |
f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) | |
x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') | |
x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') | |
x = x[:, :, ::downy, ::downx] | |
# Back to channel last. | |
x = tf.transpose(x, perm=[0, 2, 3, 1]) | |
return x | |
def filtered_lrelu(self, | |
x, fu=None, fd=None, b=None, | |
up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): | |
#fu_w, fu_h = self.get_filter_size(fu) | |
#fd_w, fd_h = self.get_filter_size(fd) | |
px0, px1, py0, py1 = self.parse_padding(padding) | |
#batch_size, in_h, in_w, channels = x.shape | |
#in_dtype = x.dtype | |
#out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down | |
#out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down | |
x = self.bias_act(x=x, b=b) | |
x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) | |
x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) | |
x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) | |
return x | |
def call(self, inputs): | |
x, w = inputs | |
styles = self.affine(w) | |
x = self.conv([x, styles]) | |
x = self.filtered_lrelu(x, | |
fu=self.u_filter, | |
fd=self.d_filter, | |
b=self.bias, | |
up=self.u_factor, | |
down=self.d_factor, | |
padding=self.padding, | |
gain=self.gain, | |
slope=self.slope, | |
clamp=self.conv_clamp) | |
return x | |
def get_config(self): | |
base_config = super(SynthesisLayer, self).get_config() | |
return base_config | |
class SynthesisLayerNoMod(Layer): | |
def __init__(self, | |
critically_sampled, | |
in_channels, | |
out_channels, | |
in_size, | |
out_size, | |
in_sampling_rate, | |
out_sampling_rate, | |
in_cutoff, | |
out_cutoff, | |
in_half_width, | |
out_half_width, | |
conv_kernel = 3, | |
lrelu_upsampling = 2, | |
filter_size = 6, | |
conv_clamp = 256, | |
use_radial_filters = False, | |
is_torgb = False, | |
batch_size = 10, | |
**kwargs): | |
super(SynthesisLayerNoMod, self).__init__(**kwargs) | |
self.critically_sampled = critically_sampled | |
self.bs = batch_size | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.in_size = np.broadcast_to(np.asarray(in_size), [2]) | |
self.out_size = np.broadcast_to(np.asarray(out_size), [2]) | |
self.in_sampling_rate = in_sampling_rate | |
self.out_sampling_rate = out_sampling_rate | |
self.in_cutoff = in_cutoff | |
self.out_cutoff = out_cutoff | |
self.in_half_width = in_half_width | |
self.out_half_width = out_half_width | |
self.is_torgb = is_torgb | |
self.conv_kernel = 1 if is_torgb else conv_kernel | |
self.lrelu_upsampling = lrelu_upsampling | |
self.conv_clamp = conv_clamp | |
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) | |
# Up sampling filter | |
self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) | |
assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate | |
self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 | |
self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, | |
cutoff=self.in_cutoff, | |
width=self.in_half_width*2, | |
fs=self.tmp_sampling_rate) | |
# Down sampling filter | |
self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) | |
assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate | |
self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 | |
self.d_radial = use_radial_filters and not self.critically_sampled | |
self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, | |
cutoff=self.out_cutoff, | |
width=self.out_half_width*2, | |
fs=self.tmp_sampling_rate, | |
radial=self.d_radial) | |
# Compute padding | |
pad_total = (self.out_size - 1) * self.d_factor + 1 | |
pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor | |
pad_total += self.u_taps + self.d_taps - 2 | |
pad_lo = (pad_total + self.u_factor) // 2 | |
pad_hi = pad_total - pad_lo | |
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] | |
self.gain = 1 if self.is_torgb else np.sqrt(2) | |
self.slope = 1 if self.is_torgb else 0.2 | |
self.act_funcs = {'linear': | |
{'func': lambda x, **_: x, | |
'def_alpha': 0, | |
'def_gain': 1}, | |
'lrelu': | |
{'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), | |
'def_alpha': 0.2, | |
'def_gain': np.sqrt(2)}, | |
} | |
b_init = tf.zeros_initializer() | |
self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), | |
dtype="float32"), | |
trainable=True) | |
self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') | |
def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): | |
if numtaps == 1: | |
return None | |
if not radial: | |
f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) | |
return f | |
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs | |
r = np.hypot(*np.meshgrid(x, x)) | |
f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) | |
beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) | |
w = np.kaiser(numtaps, beta) | |
f *= np.outer(w, w) | |
f /= np.sum(f) | |
return f | |
def get_filter_size(self, f): | |
if f is None: | |
return 1, 1 | |
assert 1 <= f.ndim <= 2 | |
return f.shape[-1], f.shape[0] # width, height | |
def parse_padding(self, padding): | |
if isinstance(padding, int): | |
padding = [padding, padding] | |
assert isinstance(padding, (list, tuple)) | |
assert all(isinstance(x, (int, np.integer)) for x in padding) | |
padding = [int(x) for x in padding] | |
if len(padding) == 2: | |
px, py = padding | |
padding = [px, px, py, py] | |
px0, px1, py0, py1 = padding | |
return px0, px1, py0, py1 | |
def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): | |
spec = self.act_funcs[act] | |
alpha = float(alpha if alpha is not None else spec['def_alpha']) | |
gain = float(gain if gain is not None else spec['def_gain']) | |
clamp = float(clamp if clamp is not None else -1) | |
if b is not None: | |
x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) | |
x = spec['func'](x, alpha=alpha) | |
if gain != 1: | |
x = x * gain | |
if clamp >= 0: | |
x = tf.clip_by_value(x, -clamp, clamp) | |
return x | |
def parse_scaling(self, scaling): | |
if isinstance(scaling, int): | |
scaling = [scaling, scaling] | |
sx, sy = scaling | |
assert sx >= 1 and sy >= 1 | |
return sx, sy | |
def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): | |
if f is None: | |
f = tf.ones([1, 1], dtype=tf.float32) | |
batch_size, in_height, in_width, num_channels = x.shape | |
batch_size = tf.shape(x)[0] | |
upx, upy = self.parse_scaling(up) | |
downx, downy = self.parse_scaling(down) | |
padx0, padx1, pady0, pady1 = self.parse_padding(padding) | |
upW = in_width * upx + padx0 + padx1 | |
upH = in_height * upy + pady0 + pady1 | |
assert upW >= f.shape[-1] and upH >= f.shape[0] | |
# Channel first format. | |
x = tf.transpose(x, perm=[0, 3, 1, 2]) | |
# Upsample by inserting zeros. | |
x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) | |
x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) | |
x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) | |
# Pad or crop. | |
x = tf.pad(x, [[0, 0], [0, 0], | |
[tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], | |
[tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) | |
x = x[:, :, | |
tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), | |
tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] | |
# Setup filter. | |
f = f * (gain ** (tf.rank(f) / 2)) | |
f = tf.cast(f, dtype=x.dtype) | |
if not flip_filter: | |
f = tf.reverse(f, axis=[-1]) | |
f = tf.reshape(f, shape=(1, 1, f.shape[-1])) | |
f = tf.repeat(f, repeats=num_channels, axis=0) | |
#if tf.rank(f) == 500: | |
# f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) | |
# x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') | |
#else: | |
f_0 = tf.expand_dims(f, axis=2) | |
f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) | |
f_1 = tf.expand_dims(f, axis=3) | |
f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) | |
x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') | |
x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') | |
x = x[:, :, ::downy, ::downx] | |
# Back to channel last. | |
x = tf.transpose(x, perm=[0, 2, 3, 1]) | |
return x | |
def filtered_lrelu(self, | |
x, fu=None, fd=None, b=None, | |
up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): | |
#fu_w, fu_h = self.get_filter_size(fu) | |
#fd_w, fd_h = self.get_filter_size(fd) | |
px0, px1, py0, py1 = self.parse_padding(padding) | |
#batch_size, in_h, in_w, channels = x.shape | |
#in_dtype = x.dtype | |
#out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down | |
#out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down | |
x = self.bias_act(x=x, b=b) | |
x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) | |
x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) | |
x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) | |
return x | |
def call(self, inputs): | |
x = inputs | |
x = self.conv(x) | |
x = self.filtered_lrelu(x, | |
fu=self.u_filter, | |
fd=self.d_filter, | |
b=self.bias, | |
up=self.u_factor, | |
down=self.d_factor, | |
padding=self.padding, | |
gain=self.gain, | |
slope=self.slope, | |
clamp=self.conv_clamp) | |
return x | |
def get_config(self): | |
        base_config = super(SynthesisLayerNoMod, self).get_config()
return base_config | |
class SynthesisLayerNoModBN(Layer): | |
def __init__(self, | |
critically_sampled, | |
in_channels, | |
out_channels, | |
in_size, | |
out_size, | |
in_sampling_rate, | |
out_sampling_rate, | |
in_cutoff, | |
out_cutoff, | |
in_half_width, | |
out_half_width, | |
conv_kernel = 3, | |
lrelu_upsampling = 2, | |
filter_size = 6, | |
conv_clamp = 256, | |
use_radial_filters = False, | |
is_torgb = False, | |
batch_size = 10, | |
**kwargs): | |
super(SynthesisLayerNoModBN, self).__init__(**kwargs) | |
self.critically_sampled = critically_sampled | |
self.bs = batch_size | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.in_size = np.broadcast_to(np.asarray(in_size), [2]) | |
self.out_size = np.broadcast_to(np.asarray(out_size), [2]) | |
self.in_sampling_rate = in_sampling_rate | |
self.out_sampling_rate = out_sampling_rate | |
self.in_cutoff = in_cutoff | |
self.out_cutoff = out_cutoff | |
self.in_half_width = in_half_width | |
self.out_half_width = out_half_width | |
self.is_torgb = is_torgb | |
self.conv_kernel = 1 if is_torgb else conv_kernel | |
self.lrelu_upsampling = lrelu_upsampling | |
self.conv_clamp = conv_clamp | |
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1.0 if is_torgb else lrelu_upsampling) | |
# Up sampling filter | |
self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) | |
assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate | |
self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 | |
self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, | |
cutoff=self.in_cutoff, | |
width=self.in_half_width*2, | |
fs=self.tmp_sampling_rate) | |
# Down sampling filter | |
self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) | |
assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate | |
self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 | |
self.d_radial = use_radial_filters and not self.critically_sampled | |
self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, | |
cutoff=self.out_cutoff, | |
width=self.out_half_width*2, | |
fs=self.tmp_sampling_rate, | |
radial=self.d_radial) | |
# Compute padding | |
pad_total = (self.out_size - 1) * self.d_factor + 1 | |
pad_total -= (self.in_size) * self.u_factor | |
pad_total += self.u_taps + self.d_taps - 2 | |
pad_lo = (pad_total + self.u_factor) // 2 | |
pad_hi = pad_total - pad_lo | |
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] | |
self.gain = 1 if self.is_torgb else np.sqrt(2) | |
self.slope = 1 if self.is_torgb else 0.2 | |
self.act_funcs = {'linear': | |
{'func': lambda x, **_: x, | |
'def_alpha': 0, | |
'def_gain': 1}, | |
'lrelu': | |
{'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), | |
'def_alpha': 0.2, | |
'def_gain': np.sqrt(2)}, | |
} | |
b_init = tf.zeros_initializer() | |
self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), | |
dtype="float32"), | |
trainable=True) | |
self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') | |
self.bn = BatchNormalization() | |
def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): | |
if numtaps == 1: | |
return None | |
if not radial: | |
f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) | |
return f | |
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs | |
r = np.hypot(*np.meshgrid(x, x)) | |
f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) | |
beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) | |
w = np.kaiser(numtaps, beta) | |
f *= np.outer(w, w) | |
f /= np.sum(f) | |
return f | |
def get_filter_size(self, f): | |
if f is None: | |
return 1, 1 | |
assert 1 <= f.ndim <= 2 | |
return f.shape[-1], f.shape[0] # width, height | |
def parse_padding(self, padding): | |
if isinstance(padding, int): | |
padding = [padding, padding] | |
assert isinstance(padding, (list, tuple)) | |
assert all(isinstance(x, (int, np.integer)) for x in padding) | |
padding = [int(x) for x in padding] | |
if len(padding) == 2: | |
px, py = padding | |
padding = [px, px, py, py] | |
px0, px1, py0, py1 = padding | |
return px0, px1, py0, py1 | |
def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): | |
spec = self.act_funcs[act] | |
alpha = float(alpha if alpha is not None else spec['def_alpha']) | |
gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32) | |
clamp = float(clamp if clamp is not None else -1) | |
if b is not None: | |
x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) | |
x = spec['func'](x, alpha=alpha) | |
x = x * gain | |
if clamp >= 0: | |
x = tf.clip_by_value(x, -clamp, clamp) | |
return x | |
def parse_scaling(self, scaling): | |
if isinstance(scaling, int): | |
scaling = [scaling, scaling] | |
sx, sy = scaling | |
assert sx >= 1 and sy >= 1 | |
return sx, sy | |
def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): | |
if f is None: | |
f = tf.ones([1, 1], dtype=tf.float32) | |
batch_size, in_height, in_width, num_channels = x.shape | |
batch_size = tf.shape(x)[0] | |
upx, upy = self.parse_scaling(up) | |
downx, downy = self.parse_scaling(down) | |
padx0, padx1, pady0, pady1 = self.parse_padding(padding) | |
upW = in_width * upx + padx0 + padx1 | |
upH = in_height * upy + pady0 + pady1 | |
assert upW >= f.shape[-1] and upH >= f.shape[0] | |
# Channel first format. | |
x = tf.transpose(x, perm=[0, 3, 1, 2]) | |
# Upsample by inserting zeros. | |
x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) | |
x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) | |
x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) | |
# Pad or crop. | |
x = tf.pad(x, [[0, 0], [0, 0], | |
[tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], | |
[tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) | |
x = x[:, :, | |
tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), | |
tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] | |
# Setup filter. | |
f = f * (gain ** (tf.rank(f) / 2)) | |
f = tf.cast(f, dtype=x.dtype) | |
if not flip_filter: | |
f = tf.reverse(f, axis=[-1]) | |
f = tf.reshape(f, shape=(1, 1, f.shape[-1])) | |
f = tf.repeat(f, repeats=num_channels, axis=0) | |
#if tf.rank(f) == 500: | |
# f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) | |
# x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') | |
#else: | |
f_0 = tf.expand_dims(f, axis=2) | |
f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) | |
f_1 = tf.expand_dims(f, axis=3) | |
f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) | |
x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') | |
x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') | |
x = x[:, :, ::downy, ::downx] | |
# Back to channel last. | |
x = tf.transpose(x, perm=[0, 2, 3, 1]) | |
return x | |
def filtered_lrelu(self, | |
x, fu=None, fd=None, b=None, | |
up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): | |
#fu_w, fu_h = self.get_filter_size(fu) | |
#fd_w, fd_h = self.get_filter_size(fd) | |
px0, px1, py0, py1 = self.parse_padding(padding) | |
#batch_size, in_h, in_w, channels = x.shape | |
#in_dtype = x.dtype | |
#out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down | |
#out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down | |
x = self.bias_act(x=x, b=b) | |
x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) | |
x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) | |
x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) | |
return x | |
def call(self, inputs): | |
x = inputs | |
x = self.conv(x) | |
x = self.bn(x) | |
x = self.filtered_lrelu(x, | |
fu=self.u_filter, | |
fd=self.d_filter, | |
b=self.bias, | |
up=self.u_factor, | |
down=self.d_factor, | |
padding=self.padding, | |
gain=self.gain, | |
slope=self.slope, | |
clamp=self.conv_clamp) | |
return x | |
def get_config(self): | |
base_config = super(SynthesisLayerNoModBN, self).get_config() | |
return base_config | |
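# Synthesis input block: draws random per-channel 2D frequencies and phases and lets a latent w drive | |
# an affine transform of the resulting Fourier features, following the SynthesisInput design of StyleGAN3. | |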
class SynthesisInput(Layer): | |
def __init__(self, | |
w_dim, | |
channels, | |
size, | |
sampling_rate, | |
bandwidth, | |
**kwargs): | |
super(SynthesisInput, self).__init__(**kwargs) | |
self.w_dim = w_dim | |
self.channels = channels | |
self.size = np.broadcast_to(np.asarray(size), [2]) | |
self.sampling_rate = sampling_rate | |
self.bandwidth = bandwidth | |
# Draw random frequencies from uniform 2D disc. | |
freqs = np.random.normal(size=(int(channels[0]), 2)) | |
radii = np.sqrt(np.sum(np.square(freqs), axis=1, keepdims=True)) | |
freqs /= radii * np.power(np.exp(np.square(radii)), 0.25) | |
freqs *= bandwidth | |
phases = np.random.uniform(size=[int(channels[0])]) - 0.5 | |
# Setup parameters and buffers. | |
w_init = tf.random_normal_initializer() | |
self.weight = tf.Variable(initial_value=w_init(shape=(self.channels, self.channels), | |
dtype="float32"), rainable=True) | |
self.affine = Dense(4, kernel_initializer=tf.zeros_initializer, | |
bias_initializer=tf.constant_initializer([1., 0., 0., 0.]))  # identity (r_c, r_s, t_x, t_y) at init, so the norm in call() is non-zero | |
self.transform = tf.eye(3, 3) | |
self.freqs = tf.constant(freqs, dtype=tf.float32) | |
self.phases = tf.constant(phases, dtype=tf.float32) | |
def call(self, w): | |
# Batch dimension | |
transforms = tf.expand_dims(self.transform, axis=0) | |
freqs = tf.expand_dims(self.freqs, axis=0) | |
phases = tf.expand_dims(self.phases, axis=0) | |
# Apply learned transformation. | |
t = self.affine(w) # t = (r_c, r_s, t_x, t_y) | |
t = t / tf.linalg.norm(t[:, :2], axis=1, keepdims=True) | |
# Inverse rotation wrt. resulting image (built by stacking, since TF tensors do not support item assignment). | |
r_c, r_s, t_x, t_y = t[:, 0], t[:, 1], t[:, 2], t[:, 3] | |
zeros = tf.zeros_like(r_c) | |
ones = tf.ones_like(r_c) | |
m_r = tf.stack([tf.stack([r_c, -r_s, zeros], axis=1), | |
tf.stack([r_s, r_c, zeros], axis=1), | |
tf.stack([zeros, zeros, ones], axis=1)], axis=1) | |
# Inverse translation wrt. resulting image. | |
m_t = tf.stack([tf.stack([ones, zeros, -t_x], axis=1), | |
tf.stack([zeros, ones, -t_y], axis=1), | |
tf.stack([zeros, zeros, ones], axis=1)], axis=1) | |
transforms = m_r @ m_t @ transforms | |
# Transform frequencies. | |
phases = phases + tf.squeeze(freqs @ transforms[:, :2, 2:], axis=2) | |
freqs = freqs @ transforms[:, :2, :2] | |
# Dampen out-of-band frequencies that may occur due to the user-specified transform. | |
amplitudes = tf.clip_by_value(1 - (tf.linalg.norm(freqs, axis=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth), 0, 1) | |
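# The reference StyleGAN3 SynthesisInput goes on to sample the transformed Fourier features on a | |
# pixel grid and apply the trainable channel mapping; the lines below are a minimal, hedged sketch | |
# of that remaining step (channels-last output), assuming self.size holds [width, height]. | |
width, height = int(self.size[0]), int(self.size[1]) | |
gx = tf.linspace(-0.5 * width / self.sampling_rate, 0.5 * width / self.sampling_rate, width) | |
gy = tf.linspace(-0.5 * height / self.sampling_rate, 0.5 * height / self.sampling_rate, height) | |
gyy, gxx = tf.meshgrid(gy, gx, indexing='ij') | |
grid = tf.stack([gxx, gyy], axis=-1)  # [height, width, 2] sampling coordinates | |
# Fourier features: sin(2*pi*(grid . freq + phase)), dampened by the per-channel amplitude. | |
feats = tf.einsum('hwd,bcd->bhwc', grid, freqs) | |
feats = feats + phases[:, tf.newaxis, tf.newaxis, :] | |
feats = tf.math.sin(feats * (2 * np.pi)) | |
feats = feats * amplitudes[:, tf.newaxis, tf.newaxis, :] | |
# Trainable mapping across channels, scaled as in the reference implementation. | |
weight = self.weight / np.sqrt(float(self.freqs.shape[0])) | |
return tf.einsum('bhwc,kc->bhwk', feats, weight) | |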
def get_config(self): | |
base_config = super(SynthesisInput, self).get_config() | |
return base_config | |
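# Synthesis layer: a final convolution whose output is re-normalized by the gated aadm() modulation | |
# (latent-driven modulated conv blended with a scale/shift computed from the auxiliary input a), then | |
# passed through the alias-free filtered leaky ReLU with this layer's up/downsampling filters. | |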
class SynthesisLayerFS(Layer): | |
def __init__(self, | |
critically_sampled, | |
in_channels, | |
out_channels, | |
in_size, | |
out_size, | |
in_sampling_rate, | |
out_sampling_rate, | |
in_cutoff, | |
out_cutoff, | |
in_half_width, | |
out_half_width, | |
conv_kernel = 3, | |
lrelu_upsampling = 2, | |
filter_size = 6, | |
conv_clamp = 256, | |
use_radial_filters = False, | |
is_torgb = False, | |
**kwargs): | |
super(SynthesisLayerFS, self).__init__(**kwargs) | |
self.critically_sampled = critically_sampled | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.in_size = np.broadcast_to(np.asarray(in_size), [2]) | |
self.out_size = np.broadcast_to(np.asarray(out_size), [2]) | |
self.in_sampling_rate = in_sampling_rate | |
self.out_sampling_rate = out_sampling_rate | |
self.in_cutoff = in_cutoff | |
self.out_cutoff = out_cutoff | |
self.in_half_width = in_half_width | |
self.out_half_width = out_half_width | |
self.is_torgb = is_torgb | |
self.conv_kernel = 1 if is_torgb else conv_kernel | |
self.lrelu_upsampling = lrelu_upsampling | |
self.conv_clamp = conv_clamp | |
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) | |
# Up sampling filter | |
self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) | |
assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate | |
self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 | |
self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, | |
cutoff=self.in_cutoff, | |
width=self.in_half_width*2, | |
fs=self.tmp_sampling_rate) | |
# Down sampling filter | |
self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) | |
assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate | |
self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 | |
self.d_radial = use_radial_filters and not self.critically_sampled | |
self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, | |
cutoff=self.out_cutoff, | |
width=self.out_half_width*2, | |
fs=self.tmp_sampling_rate, | |
radial=self.d_radial) | |
# Compute padding | |
pad_total = (self.out_size - 1) * self.d_factor + 1 | |
pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor | |
pad_total += self.u_taps + self.d_taps - 2 | |
pad_lo = (pad_total + self.u_factor) // 2 | |
pad_hi = pad_total - pad_lo | |
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] | |
self.gain = 1 if self.is_torgb else np.sqrt(2) | |
self.slope = 1 if self.is_torgb else 0.2 | |
self.act_funcs = {'linear': | |
{'func': lambda x, **_: x, | |
'def_alpha': 0, | |
'def_gain': 1}, | |
'lrelu': | |
{'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), | |
'def_alpha': 0.2, | |
'def_gain': np.sqrt(2)}, | |
} | |
b_init = tf.zeros_initializer() | |
self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), | |
dtype="float32"), | |
trainable=True) | |
self.affine = Dense(self.out_channels) | |
self.conv_mod = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same') | |
self.bn = BatchNormalization() | |
self.conv_gamma = Conv2D(self.out_channels, kernel_size=1) | |
self.conv_beta = Conv2D(self.out_channels, kernel_size=1) | |
self.conv_gate = Conv2D(self.out_channels, kernel_size=1) | |
self.conv_final = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') | |
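# Low-pass filter design: separable 1-D Kaiser windowed-sinc filters via scipy.signal.firwin, or a | |
# radially symmetric jinc (first-order Bessel) kernel tapered by an outer-product Kaiser window. | |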
def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): | |
if numtaps == 1: | |
return None | |
if not radial: | |
f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) | |
return f | |
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs | |
r = np.hypot(*np.meshgrid(x, x)) | |
f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) | |
beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) | |
w = np.kaiser(numtaps, beta) | |
f *= np.outer(w, w) | |
f /= np.sum(f) | |
return f | |
def get_filter_size(self, f): | |
if f is None: | |
return 1, 1 | |
assert 1 <= f.ndim <= 2 | |
return f.shape[-1], f.shape[0] # width, height | |
def parse_padding(self, padding): | |
if isinstance(padding, int): | |
padding = [padding, padding] | |
assert isinstance(padding, (list, tuple)) | |
assert all(isinstance(x, (int, np.integer)) for x in padding) | |
padding = [int(x) for x in padding] | |
if len(padding) == 2: | |
px, py = padding | |
padding = [px, px, py, py] | |
px0, px1, py0, py1 = padding | |
return px0, px1, py0, py1 | |
def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): | |
spec = self.act_funcs[act] | |
alpha = float(alpha if alpha is not None else spec['def_alpha']) | |
gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32) | |
clamp = float(clamp if clamp is not None else -1) | |
if b is not None: | |
x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) | |
x = spec['func'](x, alpha=alpha) | |
x = x * gain | |
if clamp >= 0: | |
x = tf.clip_by_value(x, -clamp, clamp) | |
return x | |
def parse_scaling(self, scaling): | |
if isinstance(scaling, int): | |
scaling = [scaling, scaling] | |
sx, sy = scaling | |
assert sx >= 1 and sy >= 1 | |
return sx, sy | |
def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): | |
if f is None: | |
f = tf.ones([1, 1], dtype=tf.float32) | |
batch_size, in_height, in_width, num_channels = x.shape | |
batch_size = tf.shape(x)[0] | |
upx, upy = self.parse_scaling(up) | |
downx, downy = self.parse_scaling(down) | |
padx0, padx1, pady0, pady1 = self.parse_padding(padding) | |
upW = in_width * upx + padx0 + padx1 | |
upH = in_height * upy + pady0 + pady1 | |
assert upW >= f.shape[-1] and upH >= f.shape[0] | |
# Channel first format. | |
x = tf.transpose(x, perm=[0, 3, 1, 2]) | |
# Upsample by inserting zeros. | |
x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) | |
x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]]) | |
x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) | |
# Pad or crop. | |
x = tf.pad(x, [[0, 0], [0, 0], | |
[tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)], | |
[tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)]]) | |
x = x[:, :, | |
tf.math.maximum(-pady0, 0): x.shape[2] - tf.math.maximum(-pady1, 0), | |
tf.math.maximum(-padx0, 0): x.shape[3] - tf.math.maximum(-padx1, 0)] | |
# Setup filter. | |
f = f * (gain ** (tf.rank(f) / 2)) | |
f = tf.cast(f, dtype=x.dtype) | |
if not flip_filter: | |
f = tf.reverse(f, axis=[-1]) | |
f = tf.reshape(f, shape=(1, 1, f.shape[-1])) | |
f = tf.repeat(f, repeats=num_channels, axis=0) | |
f_0 = tf.expand_dims(f, axis=2) | |
f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) | |
f_1 = tf.expand_dims(f, axis=3) | |
f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) | |
x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') | |
x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') | |
x = x[:, :, ::downy, ::downx] | |
# Back to channel last. | |
x = tf.transpose(x, perm=[0, 2, 3, 1]) | |
return x | |
def filtered_lrelu(self, | |
x, fu=None, fd=None, b=None, | |
up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): | |
#fu_w, fu_h = self.get_filter_size(fu) | |
#fd_w, fd_h = self.get_filter_size(fd) | |
px0, px1, py0, py1 = self.parse_padding(padding) | |
#batch_size, in_h, in_w, channels = x.shape | |
#in_dtype = x.dtype | |
#out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down | |
#out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down | |
x = self.bias_act(x=x, b=b) | |
x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) | |
x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) | |
x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) | |
return x | |
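# aadm: batch-normalizes x, then blends a latent path (Conv2DMod driven by the affine-projected w) with | |
# a conditioning path (per-pixel scale/shift computed from a) through a learned sigmoid gate. | |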
def aadm(self, x, w, a): | |
w_affine = self.affine(w) | |
x_norm = self.bn(x) | |
x_id = self.conv_mod([x_norm, w_affine]) | |
gate = self.conv_gate(x_norm) | |
gate = tf.nn.sigmoid(gate) | |
x_att_beta = self.conv_beta(a) | |
x_att_gamma = self.conv_gamma(a) | |
x_att = x_norm * x_att_beta + x_att_gamma | |
h = x_id * gate + (1 - gate) * x_att | |
return h | |
def call(self, inputs): | |
x, w, a = inputs | |
x = self.conv_final(x) | |
x = self.aadm(x, w, a) | |
x = self.filtered_lrelu(x, | |
fu=self.u_filter, | |
fd=self.d_filter, | |
b=self.bias, | |
up=self.u_factor, | |
down=self.d_factor, | |
padding=self.padding, | |
gain=self.gain, | |
slope=self.slope, | |
clamp=self.conv_clamp) | |
return x | |
def get_config(self): | |
base_config = super(SynthesisLayerFS, self).get_config() | |
return base_config | |
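# Variant synthesis layer that applies only the filtered up/downsampling and leaky ReLU; it owns no | |
# convolution, modulation or bias of its own. | |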
class SynthesisLayerUpDownOnly(Layer): | |
def __init__(self, | |
critically_sampled, | |
in_channels, | |
out_channels, | |
in_size, | |
out_size, | |
in_sampling_rate, | |
out_sampling_rate, | |
in_cutoff, | |
out_cutoff, | |
in_half_width, | |
out_half_width, | |
conv_kernel = 3, | |
lrelu_upsampling = 2, | |
filter_size = 6, | |
conv_clamp = 256, | |
use_radial_filters = False, | |
is_torgb = False, | |
**kwargs): | |
super(SynthesisLayerUpDownOnly, self).__init__(**kwargs) | |
self.critically_sampled = critically_sampled | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.in_size = np.broadcast_to(np.asarray(in_size), [2]) | |
self.out_size = np.broadcast_to(np.asarray(out_size), [2]) | |
self.in_sampling_rate = in_sampling_rate | |
self.out_sampling_rate = out_sampling_rate | |
self.in_cutoff = in_cutoff | |
self.out_cutoff = out_cutoff | |
self.in_half_width = in_half_width | |
self.out_half_width = out_half_width | |
self.is_torgb = is_torgb | |
self.conv_kernel = 1 if is_torgb else conv_kernel | |
self.lrelu_upsampling = lrelu_upsampling | |
self.conv_clamp = conv_clamp | |
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) | |
# Up sampling filter | |
self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) | |
assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate | |
self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 | |
self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, | |
cutoff=self.in_cutoff, | |
width=self.in_half_width*2, | |
fs=self.tmp_sampling_rate) | |
# Down sampling filter | |
self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) | |
assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate | |
self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 | |
self.d_radial = use_radial_filters and not self.critically_sampled | |
self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, | |
cutoff=self.out_cutoff, | |
width=self.out_half_width*2, | |
fs=self.tmp_sampling_rate, | |
radial=self.d_radial) | |
# Compute padding | |
pad_total = (self.out_size - 1) * self.d_factor + 1 | |
pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor | |
pad_total += self.u_taps + self.d_taps - 2 | |
pad_lo = (pad_total + self.u_factor) // 2 | |
pad_hi = pad_total - pad_lo | |
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] | |
self.gain = 1 if self.is_torgb else np.sqrt(2) | |
self.slope = 1 if self.is_torgb else 0.2 | |
self.act_funcs = {'linear': | |
{'func': lambda x, **_: x, | |
'def_alpha': 0, | |
'def_gain': 1}, | |
'lrelu': | |
{'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), | |
'def_alpha': 0.2, | |
'def_gain': np.sqrt(2)}, | |
} | |
def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): | |
if numtaps == 1: | |
return None | |
if not radial: | |
f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) | |
return f | |
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs | |
r = np.hypot(*np.meshgrid(x, x)) | |
f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) | |
beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) | |
w = np.kaiser(numtaps, beta) | |
f *= np.outer(w, w) | |
f /= np.sum(f) | |
return f | |
def get_filter_size(self, f): | |
if f is None: | |
return 1, 1 | |
assert 1 <= f.ndim <= 2 | |
return f.shape[-1], f.shape[0] # width, height | |
def parse_padding(self, padding): | |
if isinstance(padding, int): | |
padding = [padding, padding] | |
assert isinstance(padding, (list, tuple)) | |
assert all(isinstance(x, (int, np.integer)) for x in padding) | |
padding = [int(x) for x in padding] | |
if len(padding) == 2: | |
px, py = padding | |
padding = [px, px, py, py] | |
px0, px1, py0, py1 = padding | |
return px0, px1, py0, py1 | |
def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): | |
spec = self.act_funcs[act] | |
alpha = float(alpha if alpha is not None else spec['def_alpha']) | |
gain = float(gain if gain is not None else spec['def_gain']) | |
clamp = float(clamp if clamp is not None else -1) | |
if b is not None: | |
x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) | |
x = spec['func'](x, alpha=alpha) | |
if gain != 1: | |
x = x * gain | |
if clamp >= 0: | |
x = tf.clip_by_value(x, -clamp, clamp) | |
return x | |
def parse_scaling(self, scaling): | |
if isinstance(scaling, int): | |
scaling = [scaling, scaling] | |
sx, sy = scaling | |
assert sx >= 1 and sy >= 1 | |
return sx, sy | |
def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): | |
if f is None: | |
f = tf.ones([1, 1], dtype=tf.float32) | |
batch_size, in_height, in_width, num_channels = x.shape | |
batch_size = tf.shape(x)[0] | |
upx, upy = self.parse_scaling(up) | |
downx, downy = self.parse_scaling(down) | |
padx0, padx1, pady0, pady1 = self.parse_padding(padding) | |
upW = in_width * upx + padx0 + padx1 | |
upH = in_height * upy + pady0 + pady1 | |
assert upW >= f.shape[-1] and upH >= f.shape[0] | |
# Channel first format. | |
x = tf.transpose(x, perm=[0, 3, 1, 2]) | |
# Upsample by inserting zeros. | |
x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) | |
x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]]) | |
x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) | |
# Pad or crop. | |
x = tf.pad(x, [[0, 0], [0, 0], | |
[tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)], | |
[tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)]]) | |
x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] | |
# Setup filter. | |
f = f * (gain ** (f.ndim / 2)) | |
f = tf.cast(f, dtype=x.dtype) | |
if not flip_filter: | |
f = tf.reverse(f, axis=[-1]) | |
f = tf.reshape(f, shape=(1, 1, f.shape[-1])) | |
f = tf.repeat(f, repeats=num_channels, axis=0) | |
# Convolve separably along each spatial axis with the per-channel 1-D filter (2-D radial filters are not handled here), | |
# matching the other upfirdn2d implementations in this file. | |
f_0 = tf.expand_dims(f, axis=2) | |
f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) | |
f_1 = tf.expand_dims(f, axis=3) | |
f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) | |
x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') | |
x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') | |
x = x[:, :, ::downy, ::downx] | |
# Back to channel last. | |
x = tf.transpose(x, perm=[0, 2, 3, 1]) | |
return x | |
def filtered_lrelu(self, | |
x, fu=None, fd=None, b=None, | |
up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): | |
px0, px1, py0, py1 = self.parse_padding(padding) | |
x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) | |
x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) | |
x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) | |
return x | |
def call(self, inputs): | |
x = inputs | |
x = self.filtered_lrelu(x, | |
fu=self.u_filter, | |
fd=self.d_filter, | |
up=self.u_factor, | |
down=self.d_factor, | |
padding=self.padding, | |
gain=self.gain, | |
slope=self.slope, | |
clamp=self.conv_clamp) | |
return x | |
def get_config(self): | |
base_config = super(SynthesisLayerUpDownOnly, self).get_config() | |
return base_config | |
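# Spatial-transformer localization network: a small conv/dense stack that regresses a 2x3 affine matrix | |
# theta, with the final Dense initialized so the starting transform is the identity. | |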
class Localization(Layer): | |
def __init__(self): | |
super(Localization, self).__init__() | |
self.pool = MaxPooling2D() | |
self.conv_0 = Conv2D(36, 5, activation='relu') | |
self.conv_1 = Conv2D(36, 5, activation='relu') | |
self.flatten = Flatten() | |
self.fc_0 = Dense(36, activation='relu') | |
self.fc_1 = Dense(6, bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]), | |
kernel_initializer='zeros') | |
self.reshape = Reshape((2, 3)) | |
def compute_output_shape(self, input_shape): | |
return [None, 2, 3] | |
def call(self, inputs): | |
x = self.conv_0(inputs) | |
x = self.pool(x) | |
x = self.conv_1(x) | |
x = self.pool(x) | |
x = self.flatten(x) | |
x = self.fc_0(x) | |
theta = self.fc_1(x) | |
theta = self.reshape(theta) | |
return theta | |
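# Spatial-transformer sampler: generates a normalized homogeneous grid, warps it with theta and | |
# bilinearly interpolates the input images at the warped coordinates. | |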
class BilinearInterpolation(Layer): | |
def __init__(self, height=36, width=36): | |
super(BilinearInterpolation, self).__init__() | |
self.height = height | |
self.width = width | |
def compute_output_shape(self, input_shape): | |
return [None, self.height, self.width, input_shape[0][-1]] | |
def get_config(self): | |
config = { | |
'height': self.height, | |
'width': self.width | |
} | |
base_config = super(BilinearInterpolation, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def advance_indexing(self, inputs, x, y): | |
shape = tf.shape(inputs) | |
batch_size = shape[0] | |
batch_idx = tf.range(0, batch_size) | |
batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1)) | |
b = tf.tile(batch_idx, (1, self.height, self.width)) | |
indices = tf.stack([b, y, x], 3) | |
return tf.gather_nd(inputs, indices) | |
def grid_generator(self, batch): | |
x = tf.linspace(-1.0, 1.0, self.width) | |
y = tf.linspace(-1.0, 1.0, self.height) | |
xx, yy = tf.meshgrid(x, y) | |
xx = tf.reshape(xx, (-1,)) | |
yy = tf.reshape(yy, (-1,)) | |
homogenous_coordinates = tf.stack([xx, yy, tf.ones_like(xx)]) | |
homogenous_coordinates = tf.expand_dims(homogenous_coordinates, axis=0) | |
homogenous_coordinates = tf.tile(homogenous_coordinates, [batch, 1, 1]) | |
homogenous_coordinates = tf.cast(homogenous_coordinates, dtype=tf.float32) | |
return homogenous_coordinates | |
def interpolate(self, images, homogenous_coordinates, theta): | |
with tf.name_scope("Transformation"): | |
transformed = tf.matmul(theta, homogenous_coordinates) | |
transformed = tf.transpose(transformed, perm=[0, 2, 1]) | |
transformed = tf.reshape(transformed, [-1, self.height, self.width, 2]) | |
x_transformed = transformed[:, :, :, 0] | |
y_transformed = transformed[:, :, :, 1] | |
x = ((x_transformed + 1.) * tf.cast(self.width, dtype=tf.float32)) * 0.5 | |
y = ((y_transformed + 1.) * tf.cast(self.height, dtype=tf.float32)) * 0.5 | |
with tf.name_scope("VaribleCasting"): | |
x0 = tf.cast(tf.math.floor(x), dtype=tf.int32) | |
x1 = x0 + 1 | |
y0 = tf.cast(tf.math.floor(y), dtype=tf.int32) | |
y1 = y0 + 1 | |
x0 = tf.clip_by_value(x0, 0, self.width-1) | |
x1 = tf.clip_by_value(x1, 0, self.width - 1) | |
y0 = tf.clip_by_value(y0, 0, self.height - 1) | |
y1 = tf.clip_by_value(y1, 0, self.height - 1) | |
x = tf.clip_by_value(x, 0, tf.cast(self.width, dtype=tf.float32) - 1.0) | |
y = tf.clip_by_value(y, 0, tf.cast(self.height, dtype=tf.float32) - 1.0) | |
with tf.name_scope("AdvancedIndexing"): | |
i_a = self.advance_indexing(images, x0, y0) | |
i_b = self.advance_indexing(images, x0, y1) | |
i_c = self.advance_indexing(images, x1, y0) | |
i_d = self.advance_indexing(images, x1, y1) | |
with tf.name_scope("Interpolation"): | |
x0 = tf.cast(x0, dtype=tf.float32) | |
x1 = tf.cast(x1, dtype=tf.float32) | |
y0 = tf.cast(y0, dtype=tf.float32) | |
y1 = tf.cast(y1, dtype=tf.float32) | |
w_a = (x1 - x) * (y1 - y) | |
w_b = (x1 - x) * (y - y0) | |
w_c = (x - x0) * (y1 - y) | |
w_d = (x - x0) * (y - y0) | |
w_a = tf.expand_dims(w_a, axis=3) | |
w_b = tf.expand_dims(w_b, axis=3) | |
w_c = tf.expand_dims(w_c, axis=3) | |
w_d = tf.expand_dims(w_d, axis=3) | |
return tf.math.add_n([w_a * i_a, w_b * i_b, w_c * i_c, w_d * i_d]) | |
def call(self, inputs): | |
images, theta = inputs | |
homogenous_coordinates = self.grid_generator(batch=tf.shape(images)[0]) | |
return self.interpolate(images, homogenous_coordinates, theta) | |
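# Residual block (Conv-BN-LeakyReLU-Conv-BN plus identity skip) used by the learned resizer below. | |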
class ResBlockLR(Layer): | |
def __init__(self, filters=16): | |
super(ResBlockLR, self).__init__() | |
self.filters = filters | |
self.conv_0 = Conv2D(filters=filters, | |
kernel_size=3, | |
strides=1, | |
padding='same') | |
self.bn_0 = BatchNormalization() | |
self.conv_1 = Conv2D(filters=filters, | |
kernel_size=3, | |
strides=1, | |
padding='same') | |
self.bn_1 = BatchNormalization() | |
def get_config(self): | |
config = { | |
'filters': self.filters, | |
} | |
base_config = super(ResBlockLR, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call(self, inputs): | |
x = self.conv_0(inputs) | |
x = self.bn_0(x) | |
x = tf.nn.leaky_relu(x, alpha=0.2) | |
x = self.conv_1(x) | |
x = self.bn_1(x) | |
return x + inputs | |
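# Learned resizing: a plain bilinear resize refined by a residual CNN branch that is added back to the | |
# naively resized image, in the spirit of learned image-resizing front-ends. | |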
class LearnedResize(Layer): | |
def __init__(self, width, height, filters=16, in_channels=3, num_res_block=3, interpolation='bilinear'): | |
super(LearnedResize, self).__init__() | |
self.filters = filters | |
self.num_res_block = num_res_block | |
self.interpolation = interpolation | |
self.in_channels = in_channels | |
self.width = width | |
self.height = height | |
self.resize_layer = tf.keras.layers.experimental.preprocessing.Resizing(height, | |
width, | |
interpolation=interpolation) | |
self.init_layers = tf.keras.models.Sequential([Conv2D(filters=filters, | |
kernel_size=7, | |
strides=1, | |
padding='same'), | |
LeakyReLU(0.2), | |
Conv2D(filters=filters, | |
kernel_size=1, | |
strides=1, | |
padding='same'), | |
LeakyReLU(0.2), | |
BatchNormalization() | |
]) | |
res_blocks = [] | |
for i in range(num_res_block): | |
res_blocks.append(ResBlockLR(filters=filters)) | |
res_blocks.append(Conv2D(filters=filters, | |
kernel_size=3, | |
strides=1, | |
padding='same', | |
use_bias=False)) | |
res_blocks.append(BatchNormalization()) | |
self.res_block_pipe = tf.keras.models.Sequential(res_blocks) | |
self.final_conv = Conv2D(filters=in_channels, | |
kernel_size=3, | |
strides=1, | |
padding='same') | |
def compute_output_shape(self, input_shape): | |
return [None, self.height, self.width, input_shape[-1]] | |
def get_config(self): | |
config = { | |
'filters': self.filters, | |
'num_res_block': self.num_res_block, | |
'interpolation': self.interpolation, | |
'width': self.width, | |
'height': self.height, | |
} | |
base_config = super(LearnedResize, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call(self, inputs): | |
x_l = self.init_layers(inputs) | |
x_l_0 = self.resize_layer(x_l) | |
x_l = self.res_block_pipe(x_l_0) | |
x_l = x_l + x_l_0 | |
x_l = self.final_conv(x_l) | |
x = self.resize_layer(inputs) | |
return x + x_l | |
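# A minimal usage sketch (assumed shapes, illustration only): wiring the spatial-transformer pieces | |
# and the learned resizer together on dummy images. | |
if __name__ == '__main__': | |
dummy = tf.random.normal([2, 36, 36, 3]) | |
theta = Localization()(dummy)                            # [2, 2, 3] affine parameters (identity init) | |
warped = BilinearInterpolation(36, 36)([dummy, theta])   # [2, 36, 36, 3] bilinearly warped images | |
resized = LearnedResize(width=18, height=18)(dummy)      # [2, 18, 18, 3] learned downsampling | |
print(warped.shape, resized.shape) | |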