"""run_training.py

From Gwern-TWDNEv3-CPU-Generator, a fork of adriansahlman's
stylegan2_pytorch (commit 480bfbc).
"""
import warnings
import os
import torch
from torch import multiprocessing as mp
import stylegan2
from stylegan2 import utils
from stylegan2.external_models import inception, lpips
from stylegan2.metrics import fid, ppl
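
# Example invocation (illustrative only; the dataset, output and checkpoint
# paths are placeholders):
#
#   python run_training.py \
#       --data_dir ./datasets/my_images \
#       --output ./trained_models \
#       --checkpoint_dir ./checkpoints \
#       --gpu 0 1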
#----------------------------------------------------------------------------
def get_arg_parser():
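    """Build the ConfigArgumentParser holding every model, training,
    metric, data and logging option used by this script."""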
parser = utils.ConfigArgumentParser()
parser.add_argument(
'--output',
help='Output directory for model weights.',
type=str,
default=None,
metavar='DIR'
)
#----------------------------------------------------------------------------
# Model options
parser.add_argument(
'--channels',
        help='Specify the channels for each layer (can be overridden for individual ' + \
            'networks with "--g_channels" and "--d_channels"). ' + \
            'Default: %(default)s',
nargs='*',
type=int,
default=[32, 32, 64, 128, 256, 512, 512, 512, 512],
metavar='CHANNELS'
)
parser.add_argument(
'--latent',
help='Size of the prior (noise vector). Default: %(default)s',
type=int,
default=512,
metavar='VALUE'
)
parser.add_argument(
'--label',
help='Number of unique labels. Unused if not specified.',
type=int,
default=0,
metavar='VALUE'
)
parser.add_argument(
'--base_shape',
        help='Data shape of the first layer in the generator or ' + \
            'the last layer in the discriminator. Default: %(default)s',
nargs=2,
type=int,
default=(4, 4),
metavar='SIZE'
)
parser.add_argument(
'--kernel_size',
help='Size of conv kernel. Default: %(default)s',
type=int,
default=3,
metavar='SIZE'
)
parser.add_argument(
'--pad_once',
help='Pad filtered convs only once before filter instead ' + \
'of twice. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
parser.add_argument(
'--pad_mode',
help='Padding mode for conv layers. Default: %(default)s',
type=str,
default='constant',
metavar='MODE'
)
parser.add_argument(
'--pad_constant',
help='Padding constant for conv layers when `pad_mode` is ' + \
'\'constant\'. Default: %(default)s',
type=float,
default=0,
metavar='VALUE'
)
parser.add_argument(
'--filter_pad_mode',
help='Padding mode for filter layers. Default: %(default)s',
type=str,
default='constant',
metavar='MODE'
)
parser.add_argument(
'--filter_pad_constant',
help='Padding constant for filter layers when `filter_pad_mode` ' + \
'is \'constant\'. Default: %(default)s',
type=float,
default=0,
metavar='VALUE'
)
parser.add_argument(
'--filter',
help='Filter to use whenever FIR is applied. Default: %(default)s',
nargs='*',
type=float,
default=[1, 3, 3, 1],
metavar='VALUE'
)
parser.add_argument(
'--weight_scale',
help='Use weight scaling for equalized learning rate. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
#----------------------------------------------------------------------------
# Generator options
parser.add_argument(
'--g_file',
help='Load a generator model from a file instead of constructing a new one. Disabled unless a file is specified.',
type=str,
default=None,
metavar='FILE'
)
parser.add_argument(
'--g_channels',
        help='Use these channel values for the generator ' + \
            'instead of the values of "--channels".',
nargs='*',
type=int,
default=[],
metavar='CHANNELS'
)
parser.add_argument(
'--g_skip',
help='Use skip connections for the generator. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
parser.add_argument(
'--g_resnet',
help='Use resnet connections for the generator. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=False,
metavar='BOOL'
)
parser.add_argument(
'--g_conv_block_size',
help='Number of layers in a conv block in the generator. Default: %(default)s',
type=int,
default=2,
metavar='VALUE'
)
parser.add_argument(
'--g_normalize',
help='Normalize conv features for generator. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
parser.add_argument(
'--g_fused_conv',
help='Fuse conv & upsample into a transposed ' + \
'conv for the generator. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
parser.add_argument(
'--g_activation',
        help='The non-linear activation function for ' + \
            'the generator. Default: %(default)s',
default='leaky:0.2',
type=str,
metavar='ACTIVATION'
)
parser.add_argument(
'--g_conv_resample_mode',
help='Resample mode for upsampling conv ' + \
'layers for generator. Default: %(default)s',
type=str,
default='FIR',
metavar='MODE'
)
parser.add_argument(
'--g_skip_resample_mode',
help='Resample mode for skip connection ' + \
'upsamples for the generator. Default: %(default)s',
type=str,
default='FIR',
metavar='MODE'
)
parser.add_argument(
'--g_lr',
help='The learning rate for the generator. Default: %(default)s',
default=2e-3,
type=float,
metavar='VALUE'
)
parser.add_argument(
'--g_betas',
help='Beta values for the generator Adam optimizer. Default: %(default)s',
type=float,
nargs=2,
default=(0, 0.99),
metavar='VALUE'
)
parser.add_argument(
'--g_loss',
help='Loss function for the generator. Default: %(default)s',
default='logistic_ns',
type=str,
metavar='LOSS'
)
parser.add_argument(
'--g_reg',
        help='Regularization function for the generator, with an optional ' + \
            'weight appended after a ":" (for example "pathreg:2"). Default: %(default)s',
default='pathreg:2',
type=str,
metavar='REG'
)
parser.add_argument(
'--g_reg_interval',
help='Interval at which to regularize the generator. Default: %(default)s',
default=4,
type=int,
metavar='INTERVAL'
)
parser.add_argument(
'--g_iter',
help='Number of generator iterations per training iteration. Default: %(default)s',
default=1,
type=int,
metavar='ITER'
)
parser.add_argument(
'--style_mix',
help='The probability of passing more than one ' + \
'latent to the generator. Default: %(default)s',
type=float,
default=0.9,
metavar='PROBABILITY'
)
parser.add_argument(
'--latent_mapping_layers',
help='The number of layers of the latent mapping network. Default: %(default)s',
type=int,
default=8,
metavar='LAYERS'
)
parser.add_argument(
'--latent_mapping_lr_mul',
help='The learning rate multiplier for the latent ' + \
'mapping network. Default: %(default)s',
type=float,
default=0.01,
metavar='LR_MUL'
)
parser.add_argument(
'--normalize_latent',
help='Normalize latent inputs. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
parser.add_argument(
'--modulate_rgb',
help='Modulate RGB layers (use style for output ' + \
'layers of generator). Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
#----------------------------------------------------------------------------
# Discriminator options
parser.add_argument(
'--d_file',
help='Load a discriminator model from a file instead of constructing a new one. Disabled unless a file is specified.',
type=str,
default=None,
metavar='FILE'
)
parser.add_argument(
'--d_channels',
        help='Use these channel values for the discriminator ' + \
            'instead of the values of "--channels".',
nargs='*',
type=int,
default=[],
metavar='CHANNELS'
)
parser.add_argument(
'--d_skip',
help='Use skip connections for the discriminator. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=False,
metavar='BOOL'
)
parser.add_argument(
'--d_resnet',
help='Use resnet connections for the discriminator. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
parser.add_argument(
'--d_conv_block_size',
help='Number of layers in a conv block in the discriminator. Default: %(default)s',
type=int,
default=2,
metavar='VALUE'
)
parser.add_argument(
'--d_fused_conv',
help='Fuse conv & downsample into a strided ' + \
'conv for the discriminator. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
parser.add_argument(
'--group_size',
        help='Size of the groups in the minibatch std layer. Default: %(default)s',
type=int,
default=4,
metavar='VALUE'
)
parser.add_argument(
'--d_activation',
        help='The non-linear activation function for the discriminator. Default: %(default)s',
default='leaky:0.2',
type=str,
metavar='ACTIVATION'
)
parser.add_argument(
'--d_conv_resample_mode',
help='Resample mode for downsampling conv ' + \
'layers for discriminator. Default: %(default)s',
type=str,
default='FIR',
metavar='MODE'
)
parser.add_argument(
'--d_skip_resample_mode',
help='Resample mode for skip connection ' + \
'downsamples for the discriminator. Default: %(default)s',
type=str,
default='FIR',
metavar='MODE'
)
parser.add_argument(
'--d_loss',
        help='Loss function for the discriminator. Default: %(default)s',
default='logistic',
type=str,
metavar='LOSS'
)
parser.add_argument(
'--d_reg',
        help='Regularization function for the discriminator, with an optional ' + \
            'weight appended after a ":" (for example "r1:10"). Default: %(default)s',
default='r1:10',
type=str,
metavar='REG'
)
parser.add_argument(
'--d_reg_interval',
help='Interval at which to regularize the discriminator. Default: %(default)s',
default=16,
type=int,
metavar='INTERVAL'
)
parser.add_argument(
'--d_iter',
help='Number of discriminator iterations per training iteration. Default: %(default)s',
default=1,
type=int,
metavar='ITER'
)
parser.add_argument(
'--d_lr',
help='The learning rate for the discriminator. Default: %(default)s',
default=2e-3,
type=float,
metavar='VALUE'
)
parser.add_argument(
'--d_betas',
help='Beta values for the discriminator Adam optimizer. Default: %(default)s',
type=float,
nargs=2,
default=(0, 0.99),
metavar='VALUE'
)
#----------------------------------------------------------------------------
# Training options
parser.add_argument(
'--iterations',
help='Number of iterations to train for. Default: %(default)s',
type=int,
default=1000000,
metavar='ITERATIONS'
)
parser.add_argument(
'--gpu',
        help='The cuda device(s) to use. Example: "--gpu 0 1" will train ' + \
            'on GPU 0 and GPU 1. Default: Only use CPU',
type=int,
default=[],
nargs='*',
metavar='DEVICE_ID'
)
parser.add_argument(
'--distributed',
help='When more than one gpu device is passed, automatically ' + \
'start one process for each device and give it the correct ' + \
            'distributed args (rank, world_size, etc.). Disable this if ' + \
'you want training to be performed with only one process ' + \
'using the DataParallel module. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=True,
metavar='BOOL'
)
parser.add_argument(
'--rank',
help='Rank for distributed training.',
type=int,
default=None,
)
parser.add_argument(
'--world_size',
help='World size for distributed training.',
type=int,
default=None,
)
parser.add_argument(
'--master_addr',
help='Address for distributed training.',
type=str,
default=None,
)
parser.add_argument(
'--master_port',
help='Port for distributed training.',
type=str,
default=None,
)
parser.add_argument(
'--batch_size',
help='Size of each batch. Default: %(default)s',
default=32,
type=int,
metavar='VALUE'
)
parser.add_argument(
'--device_batch_size',
        help='Maximum number of items to fit on a single device at a time. Default: %(default)s',
default=4,
type=int,
metavar='VALUE'
)
parser.add_argument(
'--g_reg_batch_size',
help='Size of each batch used to regularize the generator. Default: %(default)s',
default=16,
type=int,
metavar='VALUE'
)
parser.add_argument(
'--g_reg_device_batch_size',
        help='Maximum number of items to fit on a single device when ' + \
            'regularizing the generator. Default: %(default)s',
default=2,
type=int,
metavar='VALUE'
)
parser.add_argument(
'--half',
help='Use mixed precision training. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=False,
metavar='BOOL'
)
parser.add_argument(
'--resume',
help='Resume from the latest saved checkpoint in the checkpoint_dir. ' + \
'This loads all previous training settings except for the dataset options, ' + \
            'device args (--gpu ...) and distributed training args (--rank, --world_size etc.) ' + \
'as well as metrics and logging.',
type=utils.bool_type,
const=True,
nargs='?',
default=False,
metavar='BOOL'
)
#----------------------------------------------------------------------------
# Extra metric options
parser.add_argument(
'--fid_interval',
help='If specified, evaluate the FID metric with this interval.',
default=None,
type=int,
metavar='INTERVAL'
)
parser.add_argument(
'--ppl_interval',
help='If specified, evaluate the PPL metric with this interval.',
default=None,
type=int,
metavar='INTERVAL'
)
parser.add_argument(
'--ppl_ffhq_crop',
help='Crop images evaluated for PPL with crop values for FFHQ. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=False,
metavar='BOOL'
)
#----------------------------------------------------------------------------
# Data options
parser.add_argument(
'--pixel_min',
help='Minimum of the value range of pixels in generated images. Default: %(default)s',
default=-1,
type=float,
metavar='VALUE'
)
parser.add_argument(
'--pixel_max',
help='Maximum of the value range of pixels in generated images. Default: %(default)s',
default=1,
type=float,
metavar='VALUE'
)
parser.add_argument(
'--data_channels',
help='Number of channels in the data. Default: 3 (RGB)',
default=3,
type=int,
choices=[1, 3],
metavar='CHANNELS'
)
parser.add_argument(
'--data_dir',
help='The root directory of the dataset. This argument is required!',
type=str,
default=None
)
parser.add_argument(
'--data_resize',
help='Resize data to fit input size of discriminator. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=False,
metavar='BOOL'
)
parser.add_argument(
'--mirror_augment',
help='Use random horizontal flipping for data images. Default: %(default)s',
type=utils.bool_type,
const=True,
nargs='?',
default=False,
metavar='BOOL'
)
parser.add_argument(
'--data_workers',
        help='Number of worker processes that handle data loading. Default: %(default)s',
default=4,
type=int,
metavar='WORKERS'
)
#----------------------------------------------------------------------------
# Logging options
parser.add_argument(
'--checkpoint_dir',
help='If specified, save checkpoints to this directory.',
default=None,
type=str,
metavar='DIR'
)
parser.add_argument(
'--checkpoint_interval',
help='Save checkpoints with this interval. Default: %(default)s',
default=10000,
type=int,
metavar='INTERVAL'
)
parser.add_argument(
'--tensorboard_log_dir',
help='Log to this tensorboard directory if specified.',
default=None,
type=str,
metavar='DIR'
)
parser.add_argument(
'--tensorboard_image_interval',
help='Log images to tensorboard with this interval if specified.',
default=None,
type=int,
metavar='INTERVAL'
)
parser.add_argument(
'--tensorboard_image_size',
help='Size of images logged to tensorboard. Default: %(default)s',
default=256,
type=int,
metavar='VALUE'
)
return parser
#----------------------------------------------------------------------------
def get_dataset(args):
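    """Create an ImageFolder dataset from --data_dir, sized to match the
    resolution implied by the network depth and base shape."""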
assert args.data_dir, '--data_dir has to be specified.'
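    # The networks grow by a factor of 2 per layer after the first, so the
    # final resolution is base_shape * 2**(num_layers - 1); the default
    # 9 channel entries with a 4x4 base shape give 1024x1024 images.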
height, width = [
shape * 2 ** (len(args.d_channels or args.channels) - 1)
for shape in args.base_shape
]
dataset = utils.ImageFolder(
args.data_dir,
mirror=args.mirror_augment,
pixel_min=args.pixel_min,
pixel_max=args.pixel_max,
height=height,
width=width,
resize=args.data_resize,
grayscale=args.data_channels == 1
)
assert len(dataset), 'No images found at {}'.format(args.data_dir)
return dataset
#----------------------------------------------------------------------------
def get_models(args):
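    """Load the generator and discriminator from files if given, otherwise
    construct them from the parsed arguments; the two networks must have
    the same number of layers."""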
common_kwargs = dict(
data_channels=args.data_channels,
base_shape=args.base_shape,
conv_filter=args.filter,
skip_filter=args.filter,
kernel_size=args.kernel_size,
conv_pad_mode=args.pad_mode,
conv_pad_constant=args.pad_constant,
filter_pad_mode=args.filter_pad_mode,
filter_pad_constant=args.filter_pad_constant,
pad_once=args.pad_once,
weight_scale=args.weight_scale
)
if args.g_file:
G = stylegan2.models.load(args.g_file)
assert isinstance(G, stylegan2.models.Generator), \
'`--g_file` should specify a generator model, found {}'.format(type(G))
else:
G_M = stylegan2.models.GeneratorMapping(
latent_size=args.latent,
label_size=args.label,
num_layers=args.latent_mapping_layers,
hidden=args.latent,
activation=args.g_activation,
normalize_input=args.normalize_latent,
lr_mul=args.latent_mapping_lr_mul,
weight_scale=args.weight_scale
)
G_S = stylegan2.models.GeneratorSynthesis(
channels=args.g_channels or args.channels,
latent_size=args.latent,
demodulate=args.g_normalize,
modulate_data_out=args.modulate_rgb,
conv_block_size=args.g_conv_block_size,
activation=args.g_activation,
conv_resample_mode=args.g_conv_resample_mode,
skip_resample_mode=args.g_skip_resample_mode,
resnet=args.g_resnet,
skip=args.g_skip,
fused_resample=args.g_fused_conv,
**common_kwargs
)
G = stylegan2.models.Generator(G_mapping=G_M, G_synthesis=G_S)
if args.d_file:
D = stylegan2.models.load(args.d_file)
assert isinstance(D, stylegan2.models.Discriminator), \
'`--d_file` should specify a discriminator model, found {}'.format(type(D))
else:
D = stylegan2.models.Discriminator(
channels=args.d_channels or args.channels,
label_size=args.label,
conv_block_size=args.d_conv_block_size,
activation=args.d_activation,
conv_resample_mode=args.d_conv_resample_mode,
skip_resample_mode=args.d_skip_resample_mode,
mbstd_group_size=args.group_size,
resnet=args.d_resnet,
skip=args.d_skip,
fused_resample=args.d_fused_conv,
**common_kwargs
)
assert len(G.G_synthesis.channels) == len(D.channels), \
'While the number of channels for each layer can ' + \
'differ between generator and discriminator, the ' + \
        'number of layers has to be the same. Received ' + \
'{} generator layers and {} discriminator layers.'.format(
len(G.G_synthesis.channels), len(D.channels))
return G, D
#----------------------------------------------------------------------------
def get_trainer(args):
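    """Build a Trainer (or resume it from the latest checkpoint in
    --checkpoint_dir) and register any requested FID/PPL metrics and
    tensorboard image logging on the main process."""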
dataset = get_dataset(args)
if args.resume and stylegan2.train._find_checkpoint(args.checkpoint_dir):
trainer = stylegan2.train.Trainer.load_checkpoint(
args.checkpoint_dir,
dataset,
device=args.gpu,
rank=args.rank,
world_size=args.world_size,
master_addr=args.master_addr,
master_port=args.master_port,
tensorboard_log_dir=args.tensorboard_log_dir
)
else:
G, D = get_models(args)
trainer = stylegan2.train.Trainer(
G=G,
D=D,
latent_size=args.latent,
dataset=dataset,
device=args.gpu,
batch_size=args.batch_size,
device_batch_size=args.device_batch_size,
label_size=args.label,
data_workers=args.data_workers,
G_loss=args.g_loss,
D_loss=args.d_loss,
G_reg=args.g_reg,
G_reg_interval=args.g_reg_interval,
G_opt_kwargs={'lr': args.g_lr, 'betas': args.g_betas},
G_reg_batch_size=args.g_reg_batch_size,
G_reg_device_batch_size=args.g_reg_device_batch_size,
D_reg=args.d_reg,
D_reg_interval=args.d_reg_interval,
D_opt_kwargs={'lr': args.d_lr, 'betas': args.d_betas},
style_mix_prob=args.style_mix,
G_iter=args.g_iter,
D_iter=args.d_iter,
tensorboard_log_dir=args.tensorboard_log_dir,
checkpoint_dir=args.checkpoint_dir,
checkpoint_interval=args.checkpoint_interval,
half=args.half,
rank=args.rank,
world_size=args.world_size,
master_addr=args.master_addr,
master_port=args.master_port
)
if args.fid_interval and not args.rank:
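        # Register two FID variants: one with images resized to the 299x299
        # input of InceptionV3 and one without resizing (fid_size=None).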
fid_model = inception.InceptionV3FeatureExtractor(
pixel_min=args.pixel_min, pixel_max=args.pixel_max)
trainer.register_metric(
name='FID (299x299)',
eval_fn=fid.FID(
trainer.Gs,
trainer.prior_generator,
dataset=dataset,
fid_model=fid_model,
fid_size=299,
reals_batch_size=64
),
interval=args.fid_interval
)
trainer.register_metric(
name='FID',
eval_fn=fid.FID(
trainer.Gs,
trainer.prior_generator,
dataset=dataset,
fid_model=fid_model,
fid_size=None
),
interval=args.fid_interval
)
if args.ppl_interval and not args.rank:
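        # Register PPL for path endpoints only (PPL_end) as well as with
        # full path sampling (PPL_full).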
lpips_model = lpips.LPIPS_VGG16(
pixel_min=args.pixel_min, pixel_max=args.pixel_max)
crop = None
if args.ppl_ffhq_crop:
crop = ppl.PPL.FFHQ_CROP
trainer.register_metric(
name='PPL_end',
eval_fn=ppl.PPL(
trainer.Gs,
trainer.prior_generator,
full_sampling=False,
crop=crop,
lpips_model=lpips_model,
lpips_size=256
),
interval=args.ppl_interval
)
trainer.register_metric(
name='PPL_full',
eval_fn=ppl.PPL(
trainer.Gs,
trainer.prior_generator,
full_sampling=True,
crop=crop,
lpips_model=lpips_model,
lpips_size=256
),
interval=args.ppl_interval
)
if args.tensorboard_image_interval:
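        # Log image grids for both a fixed seed ('static') and freshly
        # sampled latents ('random') at truncation psi 0.5, 0.7 and 1.0.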
for static in [True, False]:
for trunc in [0.5, 0.7, 1.0]:
if static:
name = 'static'
else:
name = 'random'
name += '/trunc_{:.1f}'.format(trunc)
trainer.add_tensorboard_image_logging(
name=name,
num_images=4,
interval=args.tensorboard_image_interval,
resize=args.tensorboard_image_size,
seed=1234567890 if static else None,
truncation_psi=trunc,
pixel_min=args.pixel_min,
pixel_max=args.pixel_max
)
return trainer
#----------------------------------------------------------------------------
def run(args):
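    """Train for --iterations and, on the main process, save the final
    G, D and Gs models to --output if it was specified."""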
if not args.rank:
if not (args.checkpoint_dir or args.output):
warnings.warn(
                'Neither an output path nor a checkpoint dir has been ' + \
'given. Weights from this training run will never ' + \
'be saved.'
)
if args.output:
assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
'--output argument should specify a directory, not a file.'
trainer = get_trainer(args)
trainer.train(iterations=args.iterations)
if not args.rank and args.output:
print('Saving models to {}'.format(args.output))
if not os.path.exists(args.output):
os.makedirs(args.output)
for model_name in ['G', 'D', 'Gs']:
getattr(trainer, model_name).save(
                os.path.join(args.output, model_name + '.pth'))
#----------------------------------------------------------------------------
def run_distributed(rank, args):
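    """Per-process entry point for mp.spawn: assign this process its rank,
    world size, GPU and master address/port, then run training."""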
args.rank = rank
args.world_size = len(args.gpu)
args.gpu = args.gpu[rank]
args.master_addr = args.master_addr or '127.0.0.1'
args.master_port = args.master_port or '23456'
run(args)
#----------------------------------------------------------------------------
def main():
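    """Parse arguments and either spawn one training process per GPU
    (when --distributed is enabled) or train in the current process."""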
parser = get_arg_parser()
args = parser.parse_args()
if len(args.gpu) > 1 and args.distributed:
assert args.rank is None and args.world_size is None, \
'When --distributed is enabled (default) the rank and ' + \
            'world size cannot be given as these are set up automatically. ' + \
'Use --distributed 0 to disable automatic setup of distributed training.'
mp.spawn(run_distributed, nprocs=len(args.gpu), args=(args,))
else:
run(args)
#----------------------------------------------------------------------------
if __name__ == '__main__':
main()