# szyezhu's picture
# Upload 46 files
# 019d164
import argparse
import traceback
import logging
import yaml
import sys
import os
import torch
import numpy as np
from boundarydiffusion import BoundaryDiffusion
from configs.paths_config import HYBRID_MODEL_PATHS
# Mode attributes that contribute a suffix to the experiment directory name.
# 'recon_exp' and 'find_best_image' are not registered on the parser, so they
# must always be read with getattr(..., False).
_NAMING_MODES = (
    'diffusion_hyperplane', 'radius', 'unconditional', 'boundary_search',
    'clip_finetune', 'clip_finetune_eff', 'clip_latent_optim',
    'edit_images_from_dataset', 'edit_image_boundary', 'edit_one_image',
    'unseen2unseen', 'recon_exp', 'find_best_image',
)


def _build_arg_parser():
    """Create the command-line parser with every mode flag and option."""
    parser = argparse.ArgumentParser(description=globals()['__doc__'])

    # Mode
    parser.add_argument('--radius', action='store_true')
    parser.add_argument('--unconditional', action='store_true')
    parser.add_argument('--boundary_search', action='store_true')
    parser.add_argument('--diffusion_hyperplane', action='store_true')
    parser.add_argument('--clip_finetune', action='store_true')
    parser.add_argument('--clip_latent_optim', action='store_true')
    parser.add_argument('--edit_images_from_dataset', action='store_true')
    parser.add_argument('--edit_one_image', action='store_true')
    parser.add_argument('--unseen2unseen', action='store_true')
    parser.add_argument('--clip_finetune_eff', action='store_true')
    parser.add_argument('--edit_one_image_eff', action='store_true')
    parser.add_argument('--edit_image_boundary', action='store_true')

    # Default
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    parser.add_argument('--seed', type=int, default=1006, help='Random seed')
    parser.add_argument('--exp', type=str, default='./runs/', help='Path for saving running related data.')
    parser.add_argument('--comment', type=str, default='', help='A string for experiment comment')
    parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
    parser.add_argument('--ni', type=int, default=1, help="No interaction. Suitable for Slurm Job launcher")
    parser.add_argument('--align_face', type=int, default=1, help='align face or not')

    # Text
    parser.add_argument('--edit_attr', type=str, default=None, help='Attribute to edit defiend in ./utils/text_dic.py')
    parser.add_argument('--src_txts', type=str, action='append', help='Source text e.g. Face')
    parser.add_argument('--trg_txts', type=str, action='append', help='Target text e.g. Angry Face')
    parser.add_argument('--target_class_num', type=str, default=None)

    # Sampling
    parser.add_argument('--t_0', type=int, default=400, help='Return step in [0, 1000)')
    parser.add_argument('--n_inv_step', type=int, default=40, help='# of steps during generative pross for inversion')
    parser.add_argument('--n_train_step', type=int, default=6, help='# of steps during generative pross for train')
    parser.add_argument('--n_test_step', type=int, default=40, help='# of steps during generative pross for test')
    parser.add_argument('--sample_type', type=str, default='ddim', help='ddpm for Markovian sampling, ddim for non-Markovian sampling')
    parser.add_argument('--eta', type=float, default=0.0, help='Controls of varaince of the generative process')
    parser.add_argument('--start_distance', type=float, default=-150.0, help='Starting distance of the editing space')
    parser.add_argument('--end_distance', type=float, default=150.0, help='Ending distance of the editing space')
    parser.add_argument('--edit_img_number', type=int, default=20, help='Number of editing images')

    # Train & Test
    parser.add_argument('--do_train', type=int, default=1, help='Whether to train or not during CLIP finetuning')
    parser.add_argument('--do_test', type=int, default=1, help='Whether to test or not during CLIP finetuning')
    parser.add_argument('--save_train_image', type=int, default=1, help='Wheter to save training results during CLIP fineuning')
    parser.add_argument('--bs_train', type=int, default=1, help='Training batch size during CLIP fineuning')
    parser.add_argument('--bs_test', type=int, default=1, help='Test batch size during CLIP fineuning')
    parser.add_argument('--n_precomp_img', type=int, default=100, help='# of images to precompute latents')
    parser.add_argument('--n_train_img', type=int, default=50, help='# of training images')
    parser.add_argument('--n_test_img', type=int, default=10, help='# of test images')
    parser.add_argument('--model_path', type=str, default=None, help='Test model path')
    parser.add_argument('--img_path', type=str, default=None, help='Image path to test')
    parser.add_argument('--deterministic_inv', type=int, default=1, help='Whether to use deterministic inversion during inference')
    parser.add_argument('--hybrid_noise', type=int, default=0, help='Whether to change multiple attributes by mixing multiple models')
    parser.add_argument('--model_ratio', type=float, default=1, help='Degree of change, noise ratio from original and finetuned model.')

    # Loss & Optimization
    parser.add_argument('--clip_loss_w', type=int, default=0, help='Weights of CLIP loss')
    parser.add_argument('--l1_loss_w', type=float, default=0, help='Weights of L1 loss')
    parser.add_argument('--id_loss_w', type=float, default=0, help='Weights of ID loss')
    parser.add_argument('--clip_model_name', type=str, default='ViT-B/16', help='ViT-B/16, ViT-B/32, RN50x16 etc')
    parser.add_argument('--lr_clip_finetune', type=float, default=2e-6, help='Initial learning rate for finetuning')
    parser.add_argument('--lr_clip_lat_opt', type=float, default=2e-2, help='Initial learning rate for latent optim')
    parser.add_argument('--n_iter', type=int, default=1, help='# of iterations of a generative process with `n_train_img` images')
    parser.add_argument('--scheduler', type=int, default=1, help='Whether to increase the learning rate')
    parser.add_argument('--sch_gamma', type=float, default=1.3, help='Scheduler gamma')
    return parser


def _image_stem(img_path):
    """Return the last '/'-separated component of *img_path*, cut at its first '.'."""
    return img_path.split("/")[-1].split(".")[0]


def _model_stem(model_path):
    """Return the final path component of *model_path* with every '.pth' removed."""
    return os.path.split(model_path)[-1].replace(".pth", "")


def _hybrid_suffix():
    """Return '_' plus the second '_'-separated token of each HYBRID_MODEL_PATHS entry, '_'-joined."""
    return '_' + '_'.join(model_name.split('_')[1] for model_name in HYBRID_MODEL_PATHS)


def _variant_suffix(args):
    """Return the hybrid-model suffix when ``args.hybrid_noise`` is truthy, else '_orig'."""
    return _hybrid_suffix() if args.hybrid_noise else '_orig'


def _experiment_suffix(args, config):
    """Build the mode-specific suffix that is appended to ``args.exp``.

    The first matching mode (in the order of the checks below) decides the
    template. Returns an empty string when none of the naming modes is set,
    in which case ``config.data.category`` is never read.
    """
    if not any(getattr(args, mode, False) for mode in _NAMING_MODES):
        return ''

    category = config.data.category
    # Fragments shared by several templates; each carries its leading '_'.
    target = args.edit_attr if args.edit_attr is not None else args.trg_txts
    t0 = f'_t{args.t_0}'
    inv_gen = f'_ninv{args.n_inv_step}_ngen{args.n_train_step}'
    losses = f'_id{args.id_loss_w}_l1{args.l1_loss_w}'
    # NOTE(review): the templates without a model path label n_train_step as
    # 'ninv'. Kept as-is so existing run directories keep their names -
    # confirm whether n_inv_step was intended.
    ninv_train = f'_ninv{args.n_train_step}'
    ninv_gen_train = f'_ninv{args.n_train_step}_ngen{args.n_train_step}'

    # Text-guided modes share one template and differ only in their tag.
    text_mode_tags = (
        (args.diffusion_hyperplane, 'SP'),
        (args.radius, 'R'),
        (args.unconditional, 'UN'),
        (args.boundary_search, 'BCLIP'),
        (args.clip_finetune or args.clip_finetune_eff, 'FT'),
    )
    for enabled, tag in text_mode_tags:
        if enabled:
            return f'_{tag}_{category}_{target}{t0}{inv_gen}{losses}_lr{args.lr_clip_finetune}'

    if args.clip_latent_optim:
        return (f'_LO_{category}_{_image_stem(args.img_path)}_{target}{t0}{inv_gen}{losses}'
                f'_lr{args.lr_clip_lat_opt}')

    if args.edit_images_from_dataset:
        if args.model_path:
            return f'_ED_{category}{t0}{inv_gen}_{_model_stem(args.model_path)}'
        return f'_ED_{category}{t0}{ninv_gen_train}' + _variant_suffix(args)

    # Both single-image editing modes produce the same '_E1_' names.
    if args.edit_image_boundary or args.edit_one_image:
        stem = _image_stem(args.img_path)
        if args.model_path:
            return f'_E1{t0}_{category}_{stem}{t0}_ninv{args.n_inv_step}_{_model_stem(args.model_path)}'
        return f'_E1_{category}_{stem}{t0}{ninv_train}' + _variant_suffix(args)

    if args.unseen2unseen:
        stem = _image_stem(args.img_path)
        if args.model_path:
            return f'_U2U{t0}_{category}_{stem}{t0}{inv_gen}_{_model_stem(args.model_path)}'
        return f'_U2U_{category}_{stem}{t0}{ninv_gen_train}' + _variant_suffix(args)

    # These two attributes have no parser option, so read them defensively
    # instead of raising AttributeError.
    if getattr(args, 'recon_exp', False):
        return f'_REC_{category}_{_image_stem(args.img_path)}{t0}{ninv_train}'
    if getattr(args, 'find_best_image', False):
        return f'_FOpt_{category}_{args.trg_txts[0]}{t0}{ninv_train}'
    return ''


def parse_args_and_config():
    """Parse the command line, load the YAML config and prepare the run.

    Side effects, in order: reads ``configs/<--config>``, extends
    ``args.exp`` with a mode-specific suffix, attaches a stream handler to
    the root logger, creates the experiment / ``checkpoint`` /
    ``precomputed`` / ``runs`` / image-sample directories, and seeds torch
    and numpy with ``--seed``.

    Returns:
        tuple: ``(args, new_config)`` where ``args`` is the parsed
        ``argparse.Namespace`` (with ``exp`` extended and ``image_folder``
        added) and ``new_config`` is the config as nested namespaces with a
        ``device`` attribute added.

    Raises:
        ValueError: if ``--verbose`` does not name an integer logging level.
        SystemExit: if the image folder exists, ``--ni`` is 0 and the user
            does not answer 'Y' to the overwrite prompt.
    """
    args = _build_arg_parser().parse_args()

    # parse config file
    with open(os.path.join('configs', args.config), 'r') as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    args.exp = args.exp + _experiment_suffix(args, new_config)

    level = getattr(logging, args.verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError('level {} not supported'.format(args.verbose))

    handler1 = logging.StreamHandler()
    formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
    handler1.setFormatter(formatter)
    logger = logging.getLogger()
    logger.addHandler(handler1)
    logger.setLevel(level)

    for directory in (args.exp, 'checkpoint', 'precomputed', 'runs'):
        os.makedirs(directory, exist_ok=True)

    args.image_folder = os.path.join(args.exp, 'image_samples')
    if os.path.exists(args.image_folder) and not args.ni:
        # Interactive runs must confirm before reusing an existing folder.
        response = input("Image folder already exists. Overwrite? (Y/N)")
        if response.upper() != 'Y':
            print("Output image folder exists. Program halted.")
            sys.exit(0)
    # "Overwrite" keeps the existing contents: nothing is deleted here.
    os.makedirs(args.image_folder, exist_ok=True)

    # add device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config
def dict2namespace(config):
    """Convert a (possibly nested) dict into nested ``argparse.Namespace`` objects.

    Every key becomes an attribute. Values that are themselves dicts are
    converted recursively; all other values are stored unchanged.
    """
    ns = argparse.Namespace()
    for name, entry in config.items():
        setattr(ns, name, dict2namespace(entry) if isinstance(entry, dict) else entry)
    return ns
def main():
    """Run the mode selected on the command line.

    Builds a ``BoundaryDiffusion`` runner from the parsed arguments and
    config, then calls the runner method for the first mode flag that is
    set, checked in the order of the chain below. Any exception raised
    while dispatching or running is logged with its traceback.

    Returns:
        int: 0 on success, 1 if the selected mode raised or no dispatchable
        mode was chosen, so that ``sys.exit(main())`` reports failures.
    """
    args, config = parse_args_and_config()
    print(">" * 80)
    logging.info("Exp instance id = {}".format(os.getpid()))
    logging.info("Exp comment = {}".format(args.comment))
    logging.info("Config =")
    print("<" * 80)

    runner = BoundaryDiffusion(args, config)
    try:
        # NOTE(review): only these six modes are dispatched; any other mode
        # flag ends up in the 'Choose one mode!' branch.
        if args.clip_finetune:
            runner.clip_finetune()
        elif args.radius:
            runner.radius()
        elif args.unconditional:
            runner.unconditional()
        elif args.diffusion_hyperplane:
            runner.diffusion_hyperplane()
        elif args.boundary_search:
            runner.boundary_search()
        elif args.edit_image_boundary:
            runner.edit_image_boundary()
        else:
            print('Choose one mode!')
            raise ValueError('Choose one mode!')
    except Exception:
        logging.error(traceback.format_exc())
        # Report the failure through the exit status instead of returning 0.
        return 1

    return 0
# Script entry point: run main() and use its return value as the exit status.
if __name__ == '__main__':
    exit_status = main()
    sys.exit(exit_status)