# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

import pickle
import inspect
import numpy as np
import tfutil
import networks

#----------------------------------------------------------------------------
# Custom unpickler that is able to load network pickles produced by
# the old Theano implementation.

class LegacyUnpickler(pickle.Unpickler):
    """Unpickler that redirects the legacy ``network.Network`` class.

    Any pickled reference to ``Network`` in module ``network`` is resolved to
    ``tfutil.Network``; every other class lookup is delegated unchanged to
    ``pickle.Unpickler``.

    NOTE(review): like any unpickler, this executes code chosen by the pickle
    stream, so it should only be fed files from a trusted source.
    """

    def find_class(self, module, name):
        """Return the class object to use for ``module.name`` while loading."""
        # Only the exact (module, name) pair of the old implementation is remapped.
        if (module, name) == ('network', 'Network'):
            return tfutil.Network
        return super().find_class(module, name)

#----------------------------------------------------------------------------
# Import handler for tfutil.Network that silently converts networks produced
# by the old Theano implementation to a suitable format.
# Maps the build function name stored in a legacy pickle to the name of the
# build function that is recorded in the converted state.
theano_gan_remap = {
    'G_paper':         'G_paper',
    'G_progressive_8': 'G_paper',
    'D_paper':         'D_paper',
    'D_progressive_8': 'D_paper',
}

def patch_theano_gan(state):
    """Convert a legacy (unversioned) GAN network state to the version 2 format.

    Args:
        state: Network state dict. It is returned untouched when it already
            has a 'version' key or when its build function is not listed in
            ``theano_gan_remap``.

    Returns:
        Either ``state`` itself, or a new dict with the keys 'version',
        'name', 'build_module_src', 'build_func_name', 'static_kwargs' and
        'variables'.

    Raises:
        AssertionError: If the legacy build spec sets one of the checked
            options to a value other than the single one accepted below.
    """
    # Guard clauses: nothing to do for versioned or unrecognized networks.
    if 'version' in state:
        return state
    if state['build_func_spec']['func'] not in theano_gan_remap:
        return state

    # Work on a copy so that the caller's build spec is never mutated.
    spec = dict(state['build_func_spec'])
    func = spec.pop('func')
    resolution = spec.get('resolution', 32)
    resolution_log2 = int(np.log2(resolution))
    use_wscale = spec.get('use_wscale', True)

    # Options that have no counterpart in the converted network are removed
    # from the spec; only their listed values are accepted.
    assert spec.pop('label_size', 0) == 0
    assert spec.pop('use_batchnorm', False) == False
    assert spec.pop('tanh_at_end', None) is None
    assert spec.pop('mbstat_func', 'Tstdeps') == 'Tstdeps'
    assert spec.pop('mbstat_avg', 'all') == 'all'
    assert spec.pop('mbdisc_kernels', None) is None
    spec.pop('use_gdrop', True) # doesn't make a difference
    assert spec.pop('use_layernorm', False) == False
    spec['fused_scale'] = False
    spec['mbstd_group_size'] = 16

    variables = []
    # Parameters are consumed strictly in stored order; every next() call
    # below therefore has to happen in exactly this sequence.
    params = iter(state['param_values'])
    relu = np.sqrt(2)
    linear = 1.0

    def flatten2(w):
        # Collapse everything after the first axis into a single axis.
        return w.reshape(w.shape[0], -1)

    def he_std(gain, w):
        return gain / np.sqrt(np.prod(w.shape[:-1]))

    def wscale(gain, w):
        # With use_wscale, one extra stored parameter follows each weight and
        # is folded into it here.
        if not use_wscale:
            return w
        return w * next(params) / he_std(gain, w)

    def layer(name, gain, w):
        # The weight (and its optional scale) is read before the bias.
        weight = wscale(gain, w)
        bias = next(params)
        return [(name + '/weight', weight), (name + '/bias', bias)]

    def conv_kernel(w):
        # Move the first two axes to the end (swapped) and reverse the two
        # leading axes of the result.
        # NOTE(review): presumably (out, in, h, w) -> (h, w, in, out) with a
        # spatial flip; the stored layout is not visible here -- confirm.
        return w.transpose(2, 3, 1, 0)[::-1, ::-1]

    def pointwise_kernel(w):
        # Prepend two singleton axes.
        return w[np.newaxis, np.newaxis]

    if func.startswith('G'):
        variables.extend(layer('4x4/Dense', relu / 4, flatten2(next(params).transpose(1, 0, 2, 3))))
        variables.extend(layer('4x4/Conv', relu, conv_kernel(next(params))))
        for res in range(3, resolution_log2 + 1):
            size = 2**res
            variables.extend(layer('%dx%d/Conv0' % (size, size), relu, conv_kernel(next(params))))
            variables.extend(layer('%dx%d/Conv1' % (size, size), relu, conv_kernel(next(params))))
        for lod in range(0, resolution_log2 - 1):
            variables.extend(layer('ToRGB_lod%d' % lod, linear, pointwise_kernel(next(params))))

    if func.startswith('D'):
        variables.extend(layer('FromRGB_lod0', relu, pointwise_kernel(next(params))))
        for res in range(resolution_log2, 2, -1):
            size = 2**res
            variables.extend(layer('%dx%d/Conv0' % (size, size), relu, conv_kernel(next(params))))
            variables.extend(layer('%dx%d/Conv1' % (size, size), relu, conv_kernel(next(params))))
            variables.extend(layer('FromRGB_lod%d' % (resolution_log2 - (res - 1)), relu, pointwise_kernel(next(params))))
        variables.extend(layer('4x4/Conv', relu, conv_kernel(next(params))))
        variables.extend(layer('4x4/Dense0', relu, flatten2(next(params)[:, :, ::-1, ::-1]).transpose()))
        variables.extend(layer('4x4/Dense1', linear, next(params)))

    variables.append(('lod', state['toplevel_params']['cur_lod']))

    return {
        'version': 2,
        'name': func,
        'build_module_src': inspect.getsource(networks),
        'build_func_name': theano_gan_remap[func],
        'static_kwargs': spec,
        'variables': variables,
    }

tfutil.network_import_handlers.append(patch_theano_gan)

#----------------------------------------------------------------------------
# Import handler for tfutil.Network that ignores unsupported/deprecated
# networks produced by older versions of the code.

def ignore_unknown_theano_network(state):
    """Replace any still-unversioned network state with a dummy network.

    Args:
        state: Network state dict. It is returned untouched when it has a
            'version' key.

    Returns:
        ``state`` itself, or a version 2 state describing a network named
        'Dummy' with no variables. A notice naming the ignored build function
        is printed in the latter case.
    """
    if 'version' in state:
        return state

    print('Ignoring unknown Theano network:', state['build_func_spec']['func'])
    dummy_state = {
        'version': 2,
        'name': 'Dummy',
        'build_module_src': 'def dummy(input, **kwargs): input.set_shape([None, 1]); return input',
        'build_func_name': 'dummy',
        'static_kwargs': {},
        'variables': [],
    }
    return dummy_state

tfutil.network_import_handlers.append(ignore_unknown_theano_network)

#----------------------------------------------------------------------------