from keras.layers import Conv2D, Activation, Input, Concatenate, LeakyReLU, Lambda, AveragePooling2D, UpSampling2D, Convolution2D, BatchNormalization, Conv2DTranspose, Add from keras.models import Model from InstanceNorm import InstanceNormalization def make_standard_UNET(channels,outs): def relu(x): return Activation('relu')(x) def concat(x): return Concatenate()(x) c0 = Convolution2D(filters=32, kernel_size=3, strides=1, padding='same', name='c0') c1 = Convolution2D(filters=64, kernel_size=4, strides=2, padding='same', name='c1') c2 = Convolution2D(filters=64, kernel_size=3, strides=1, padding='same', name='c2') c3 = Convolution2D(filters=128, kernel_size=4, strides=2, padding='same', name='c3') c4 = Convolution2D(filters=128, kernel_size=3, strides=1, padding='same', name='c4') c5 = Convolution2D(filters=256, kernel_size=4, strides=2, padding='same', name='c5') c6 = Convolution2D(filters=256, kernel_size=3, strides=1, padding='same', name='c6') c7 = Convolution2D(filters=512, kernel_size=4, strides=2, padding='same', name='c7') c8 = Convolution2D(filters=512, kernel_size=3, strides=1, padding='same', name='c8') bnc0 = BatchNormalization(axis=3, name='bnc0') bnc1 = BatchNormalization(axis=3, name='bnc1') bnc2 = BatchNormalization(axis=3, name='bnc2') bnc3 = BatchNormalization(axis=3, name='bnc3') bnc4 = BatchNormalization(axis=3, name='bnc4') bnc5 = BatchNormalization(axis=3, name='bnc5') bnc6 = BatchNormalization(axis=3, name='bnc6') bnc7 = BatchNormalization(axis=3, name='bnc7') bnc8 = BatchNormalization(axis=3, name='bnc8') dc8 = Conv2DTranspose(filters=512, kernel_size=4, strides=2, padding='same', name='dc8_') dc7 = Convolution2D(filters=256, kernel_size=3, strides=1, padding='same', name='dc7') dc6 = Conv2DTranspose(filters=256, kernel_size=4, strides=2, padding='same', name='dc6_') dc5 = Convolution2D(filters=128, kernel_size=3, strides=1, padding='same', name='dc5') dc4 = Conv2DTranspose(filters=128, kernel_size=4, strides=2, padding='same', name='dc4_') dc3 = Convolution2D(filters=64, kernel_size=3, strides=1, padding='same', name='dc3') dc2 = Conv2DTranspose(filters=64, kernel_size=4, strides=2, padding='same', name='dc2_') dc1 = Convolution2D(filters=32, kernel_size=3, strides=1, padding='same', name='dc1') dc0 = Convolution2D(filters=outs, kernel_size=3, strides=1, padding='same', name='dc0') bnd1 = BatchNormalization(axis=3, name='bnd1') bnd2 = BatchNormalization(axis=3, name='bnd2') bnd3 = BatchNormalization(axis=3, name='bnd3') bnd4 = BatchNormalization(axis=3, name='bnd4') bnd5 = BatchNormalization(axis=3, name='bnd5') bnd6 = BatchNormalization(axis=3, name='bnd6') bnd7 = BatchNormalization(axis=3, name='bnd7') bnd8 = BatchNormalization(axis=3, name='bnd8') x = Input(shape=(128, 128, channels)) e0 = relu(bnc0(c0(x), training = False)) e1 = relu(bnc1(c1(e0), training = False)) e2 = relu(bnc2(c2(e1), training = False)) e3 = relu(bnc3(c3(e2), training = False)) e4 = relu(bnc4(c4(e3), training = False)) e5 = relu(bnc5(c5(e4), training = False)) e6 = relu(bnc6(c6(e5), training = False)) e7 = relu(bnc7(c7(e6), training = False)) e8 = relu(bnc8(c8(e7), training = False)) d8 = relu(bnd8(dc8(concat([e7, e8])), training = False)) d7 = relu(bnd7(dc7(d8), training = False)) d6 = relu(bnd6(dc6(concat([e6, d7])), training = False)) d5 = relu(bnd5(dc5(d6), training = False)) d4 = relu(bnd4(dc4(concat([e4, d5])), training = False)) d3 = relu(bnd3(dc3(d4), training = False)) d2 = relu(bnd2(dc2(concat([e2, d3])), training = False)) d1 = relu(bnd1(dc1(d2), training = False)) d0 = dc0(concat([e0, d1])) model = Model(inputs=x,outputs=d0) return model def make_diff_net(): def conv(x, filters, name): return Conv2D(filters=filters, strides=(1, 1), kernel_size=(3, 3), padding='same', name=name)(x) def relu(x): return Activation('relu')(x) def lrelu(x): return LeakyReLU(alpha=0.1)(x) def r_block(x, filters, name=None): return relu(conv(relu(conv(x, filters, None if name is None else name + '_c1')), filters, None if name is None else name + '_c2')) def cat(a, b): return Concatenate()([UpSampling2D((2, 2))(a), b]) def dog(x): down = AveragePooling2D((2, 2))(x) up = UpSampling2D((2, 2))(down) diff = Lambda(lambda p: p[0] - p[1])([x, up]) return down, diff ip = Input(shape=(512, 512, 3)) c512 = r_block(ip, 16, 'c512') c256, l512 = dog(c512) c256 = r_block(c256, 32, 'c256') c128, l256 = dog(c256) c128 = r_block(c128, 64, 'c128') c64, l128 = dog(c128) c64 = r_block(c64, 128, 'c64') c32, l64 = dog(c64) c32 = r_block(c32, 256, 'c32') c16, l32 = dog(c32) c16 = r_block(c16, 512, 'c16') d32 = cat(c16, l32) d32 = r_block(d32, 256, 'd32') d64 = cat(d32, l64) d64 = r_block(d64, 128, 'd64') d128 = cat(d64, l128) d128 = r_block(d128, 64, 'd128') d256 = cat(d128, l256) d256 = r_block(d256, 32, 'd256') d512 = cat(d256, l512) d512 = r_block(d512, 16, 'd512') op = conv(d512, 1, 'op') return Model(inputs=ip, outputs=op) def make_wnet256(): def conv(x, filters): return Conv2D(filters=filters, strides=(1, 1), kernel_size=(3, 3), padding='same')(x) def relu(x): return Activation('relu')(x) def lrelu(x): return LeakyReLU(alpha=0.1)(x) def r_block(x, filters): return relu(conv(relu(conv(x, filters)), filters)) def res_block(x, filters): return relu(Add()([x, conv(relu(conv(x, filters)), filters)])) def cat(a, b): return Concatenate()([UpSampling2D((2, 2))(a), b]) def dog(x): down = AveragePooling2D((2, 2))(x) up = UpSampling2D((2, 2))(down) diff = Lambda(lambda p: p[0] - p[1])([x, up]) return down, diff ip_sketch = Input(shape=(256, 256, 1)) ip_color = Input(shape=(256, 256, 3)) c256 = r_block(ip_sketch, 32) c128, l256 = dog(c256) c128 = r_block(c128, 64) c64, l128 = dog(c128) c64 = r_block(c64, 128) c32, l64 = dog(c64) c32 = r_block(Concatenate()([c32, AveragePooling2D((8, 8))(ip_color)]), 256) c32 = res_block(c32, 256) c32 = res_block(c32, 256) c32 = res_block(c32, 256) c32 = res_block(c32, 256) c32 = res_block(c32, 256) c32 = res_block(c32, 256) c32 = res_block(c32, 256) c32 = res_block(c32, 256) c32 = res_block(c32, 256) d64 = cat(c32, l64) d64 = r_block(d64, 128) d128 = cat(d64, l128) d128 = r_block(d128, 64) d256 = cat(d128, l256) d256 = r_block(d256, 32) op = conv(d256, 3) return Model(inputs=[ip_sketch, ip_color], outputs=op) def make_unet512(): def conv(x, filters, strides=(1, 1), kernel_size=(3, 3)): return Conv2D(filters=filters, strides=strides, kernel_size=kernel_size, padding='same')(x) def donv(x, filters, strides=(2, 2), kernel_size=(4, 4)): return Conv2DTranspose(filters=filters, strides=strides, kernel_size=kernel_size, padding='same')(x) def relu(x): return Activation('relu')(x) def sigmoid(x): return Activation('sigmoid')(x) def norm(x): return InstanceNormalization(axis=3)(x) def cat(a, b): return Concatenate()([a, b]) def res(x, filters): c1 = relu(norm(conv(x, filters // 2))) c2 = norm(conv(c1, filters)) ad = Add()([x, c2]) return relu(ad) ip = Input(shape=(512, 512, 3)) c512 = relu(norm(conv(ip, 16, strides=(1, 1), kernel_size=(3, 3)))) c256 = relu(norm(conv(c512, 32, strides=(2, 2), kernel_size=(4, 4)))) c128 = relu(norm(conv(c256, 64, strides=(2, 2), kernel_size=(4, 4)))) c128 = res(c128, 64) c64 = relu(norm(conv(c128, 128, strides=(2, 2), kernel_size=(4, 4)))) c64 = res(c64, 128) c64 = res(c64, 128) c32 = relu(norm(conv(c64, 256, strides=(2, 2), kernel_size=(4, 4)))) c32 = res(c32, 256) c32 = res(c32, 256) c32 = res(c32, 256) c32 = res(c32, 256) c32 = res(c32, 256) c32 = res(c32, 256) c32 = res(c32, 256) c32 = res(c32, 256) c16 = relu(norm(conv(c32, 512, strides=(2, 2), kernel_size=(4, 4)))) c16 = res(c16, 512) c16 = res(c16, 512) c16 = res(c16, 512) c16 = res(c16, 512) c16 = res(c16, 512) c16 = res(c16, 512) c16 = res(c16, 512) c16 = res(c16, 512) c8 = relu(norm(conv(c16, 1024, strides=(2, 2), kernel_size=(4, 4)))) c8 = res(c8, 1024) c8 = res(c8, 1024) c8 = res(c8, 1024) c8 = res(c8, 1024) e16 = relu(norm(donv(c8, 512, strides=(2, 2), kernel_size=(4, 4)))) e16 = cat(e16, c16) e16 = relu(norm(conv(e16, 512, strides=(1, 1), kernel_size=(3, 3)))) e32 = relu(norm(donv(e16, 256, strides=(2, 2), kernel_size=(4, 4)))) e32 = cat(e32, c32) e32 = relu(norm(conv(e32, 256, strides=(1, 1), kernel_size=(3, 3)))) e64 = relu(norm(donv(e32, 128, strides=(2, 2), kernel_size=(4, 4)))) e64 = cat(e64, c64) e64 = relu(norm(conv(e64, 128, strides=(1, 1), kernel_size=(3, 3)))) e128 = relu(norm(donv(e64, 64, strides=(2, 2), kernel_size=(4, 4)))) e128 = cat(e128, c128) e128 = relu(norm(conv(e128, 64, strides=(1, 1), kernel_size=(3, 3)))) e256 = relu(norm(donv(e128, 32, strides=(2, 2), kernel_size=(4, 4)))) e256 = cat(e256, c256) e256 = relu(norm(conv(e256, 32, strides=(1, 1), kernel_size=(3, 3)))) e512 = relu(norm(donv(e256, 16, strides=(2, 2), kernel_size=(4, 4)))) e512 = cat(e512, c512) e512 = relu(norm(conv(e512, 16, strides=(1, 1), kernel_size=(3, 3)))) ot = sigmoid(conv(e512, 1)) return Model(inputs=ip, outputs=ot)