|
import argparse |
|
import traceback |
|
import logging |
|
import yaml |
|
import sys |
|
import os |
|
import torch |
|
import numpy as np |
|
|
|
from boundarydiffusion import BoundaryDiffusion |
|
from configs.paths_config import HYBRID_MODEL_PATHS |
|
|
|
def _hybrid_suffix():
    """Return the '_a_b_c' tag joining the second '_'-token of every hybrid model name."""
    # NOTE(review): assumes each HYBRID_MODEL_PATHS entry looks like
    # '<prefix>_<attr>_<...>.pth' — confirm against configs/paths_config.py.
    return '_' + '_'.join(name.split('_')[1] for name in HYBRID_MODEL_PATHS)


def _img_stem(img_path):
    """Return the image file name without directories or extension."""
    return img_path.split("/")[-1].split(".")[0]


def _exp_suffix(args, new_config):
    """Build the experiment-directory suffix for the selected mode.

    The format strings intentionally reproduce the historical names
    (including the branches that label ``n_train_step`` as ``ninv`` and the
    duplicated ``t{t_0}`` segments), so existing runs keep resolving to the
    same directories. Returns '' when no mode flag is set.
    """
    cat = new_config.data.category
    # Label for the edit direction: explicit attribute if given, otherwise
    # the (list-valued) target texts, rendered exactly as before.
    tag = args.edit_attr if args.edit_attr is not None else args.trg_txts
    # Common tail shared by the training-style modes.
    ft_tail = (f't{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}'
               f'_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}')
    # Checkpoint base name without the '.pth' extension (None if unset).
    model_base = (os.path.split(args.model_path)[-1].replace(".pth", "")
                  if args.model_path else None)

    if args.diffusion_hyperplane:
        return f'_SP_{cat}_{tag}_{ft_tail}'
    if args.radius:
        return f'_R_{cat}_{tag}_{ft_tail}'
    if args.unconditional:
        return f'_UN_{cat}_{tag}_{ft_tail}'
    if args.boundary_search:
        return f'_BCLIP_{cat}_{tag}_{ft_tail}'
    if args.clip_finetune or args.clip_finetune_eff:
        return f'_FT_{cat}_{tag}_{ft_tail}'
    if args.clip_latent_optim:
        # Same tail as ft_tail but with the latent-optimization learning rate.
        lo_tail = (f't{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}'
                   f'_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_lat_opt}')
        return f'_LO_{cat}_{_img_stem(args.img_path)}_{tag}_{lo_tail}'
    if args.edit_images_from_dataset:
        if args.model_path:
            return f'_ED_{cat}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_{model_base}'
        if args.hybrid_noise:
            return f'_ED_{cat}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}' + _hybrid_suffix()
        return f'_ED_{cat}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}_orig'
    if args.edit_image_boundary or args.edit_one_image:
        # These two modes historically produced byte-identical names.
        stem = _img_stem(args.img_path)
        if args.model_path:
            return f'_E1_t{args.t_0}_{cat}_{stem}_t{args.t_0}_ninv{args.n_inv_step}_{model_base}'
        if args.hybrid_noise:
            return f'_E1_{cat}_{stem}_t{args.t_0}_ninv{args.n_train_step}' + _hybrid_suffix()
        return f'_E1_{cat}_{stem}_t{args.t_0}_ninv{args.n_train_step}_orig'
    if args.unseen2unseen:
        stem = _img_stem(args.img_path)
        if args.model_path:
            return (f'_U2U_t{args.t_0}_{cat}_{stem}_t{args.t_0}'
                    f'_ninv{args.n_inv_step}_ngen{args.n_train_step}_{model_base}')
        if args.hybrid_noise:
            return (f'_U2U_{cat}_{stem}_t{args.t_0}'
                    f'_ninv{args.n_train_step}_ngen{args.n_train_step}') + _hybrid_suffix()
        return (f'_U2U_{cat}_{stem}_t{args.t_0}'
                f'_ninv{args.n_train_step}_ngen{args.n_train_step}_orig')
    if args.recon_exp:
        return f'_REC_{cat}_{_img_stem(args.img_path)}_t{args.t_0}_ninv{args.n_train_step}'
    if args.find_best_image:
        return f'_FOpt_{cat}_{args.trg_txts[0]}_t{args.t_0}_ninv{args.n_train_step}'
    return ''


def parse_args_and_config():
    """Parse CLI options plus the YAML config and prepare the run.

    Side effects:
        * configures root logging at the requested verbosity,
        * creates the experiment / checkpoint / output directories
          (prompting before overwriting unless ``--ni``),
        * seeds torch and numpy and selects the compute device.

    Returns:
        tuple: ``(args, new_config)`` — the CLI namespace and the YAML
        config converted to a nested Namespace with ``.device`` attached.

    Raises:
        ValueError: if ``--verbose`` is not a recognized logging level.
    """
    parser = argparse.ArgumentParser(description=globals()['__doc__'])

    # ---- Mode flags: exactly one should be chosen (dispatched in main()) ----
    parser.add_argument('--radius', action='store_true')
    parser.add_argument('--unconditional', action='store_true')
    parser.add_argument('--boundary_search', action='store_true')
    parser.add_argument('--diffusion_hyperplane', action='store_true')
    parser.add_argument('--clip_finetune', action='store_true')
    parser.add_argument('--clip_latent_optim', action='store_true')
    parser.add_argument('--edit_images_from_dataset', action='store_true')
    parser.add_argument('--edit_one_image', action='store_true')
    parser.add_argument('--unseen2unseen', action='store_true')
    parser.add_argument('--clip_finetune_eff', action='store_true')
    parser.add_argument('--edit_one_image_eff', action='store_true')
    parser.add_argument('--edit_image_boundary', action='store_true')
    # BUGFIX: these two flags were read below (args.recon_exp /
    # args.find_best_image) but never registered, so any invocation that
    # fell through to those branches crashed with AttributeError.
    parser.add_argument('--recon_exp', action='store_true')
    parser.add_argument('--find_best_image', action='store_true')

    # ---- General experiment options ----
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    parser.add_argument('--seed', type=int, default=1006, help='Random seed')
    parser.add_argument('--exp', type=str, default='./runs/', help='Path for saving running related data.')
    parser.add_argument('--comment', type=str, default='', help='A string for experiment comment')
    parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
    parser.add_argument('--ni', type=int, default=1, help="No interaction. Suitable for Slurm Job launcher")
    parser.add_argument('--align_face', type=int, default=1, help='align face or not')

    # ---- Editing target ----
    parser.add_argument('--edit_attr', type=str, default=None, help='Attribute to edit defined in ./utils/text_dic.py')
    parser.add_argument('--src_txts', type=str, action='append', help='Source text e.g. Face')
    parser.add_argument('--trg_txts', type=str, action='append', help='Target text e.g. Angry Face')
    parser.add_argument('--target_class_num', type=str, default=None)

    # ---- Sampling / inversion schedule ----
    parser.add_argument('--t_0', type=int, default=400, help='Return step in [0, 1000)')
    parser.add_argument('--n_inv_step', type=int, default=40, help='# of steps during generative process for inversion')
    parser.add_argument('--n_train_step', type=int, default=6, help='# of steps during generative process for train')
    parser.add_argument('--n_test_step', type=int, default=40, help='# of steps during generative process for test')
    parser.add_argument('--sample_type', type=str, default='ddim', help='ddpm for Markovian sampling, ddim for non-Markovian sampling')
    parser.add_argument('--eta', type=float, default=0.0, help='Controls of variance of the generative process')
    parser.add_argument('--start_distance', type=float, default=-150.0, help='Starting distance of the editing space')
    parser.add_argument('--end_distance', type=float, default=150.0, help='Ending distance of the editing space')
    parser.add_argument('--edit_img_number', type=int, default=20, help='Number of editing images')

    # ---- Train & test ----
    parser.add_argument('--do_train', type=int, default=1, help='Whether to train or not during CLIP finetuning')
    parser.add_argument('--do_test', type=int, default=1, help='Whether to test or not during CLIP finetuning')
    parser.add_argument('--save_train_image', type=int, default=1, help='Whether to save training results during CLIP finetuning')
    parser.add_argument('--bs_train', type=int, default=1, help='Training batch size during CLIP finetuning')
    parser.add_argument('--bs_test', type=int, default=1, help='Test batch size during CLIP finetuning')
    parser.add_argument('--n_precomp_img', type=int, default=100, help='# of images to precompute latents')
    parser.add_argument('--n_train_img', type=int, default=50, help='# of training images')
    parser.add_argument('--n_test_img', type=int, default=10, help='# of test images')
    parser.add_argument('--model_path', type=str, default=None, help='Test model path')
    parser.add_argument('--img_path', type=str, default=None, help='Image path to test')
    parser.add_argument('--deterministic_inv', type=int, default=1, help='Whether to use deterministic inversion during inference')
    parser.add_argument('--hybrid_noise', type=int, default=0, help='Whether to change multiple attributes by mixing multiple models')
    parser.add_argument('--model_ratio', type=float, default=1, help='Degree of change, noise ratio from original and finetuned model.')

    # ---- Losses & optimization ----
    # BUGFIX: clip_loss_w was type=int while the sibling loss weights are
    # float, which silently forbade fractional CLIP-loss weights.
    parser.add_argument('--clip_loss_w', type=float, default=0, help='Weights of CLIP loss')
    parser.add_argument('--l1_loss_w', type=float, default=0, help='Weights of L1 loss')
    parser.add_argument('--id_loss_w', type=float, default=0, help='Weights of ID loss')
    parser.add_argument('--clip_model_name', type=str, default='ViT-B/16', help='ViT-B/16, ViT-B/32, RN50x16 etc')
    parser.add_argument('--lr_clip_finetune', type=float, default=2e-6, help='Initial learning rate for finetuning')
    parser.add_argument('--lr_clip_lat_opt', type=float, default=2e-2, help='Initial learning rate for latent optim')
    parser.add_argument('--n_iter', type=int, default=1, help='# of iterations of a generative process with `n_train_img` images')
    parser.add_argument('--scheduler', type=int, default=1, help='Whether to increase the learning rate')
    parser.add_argument('--sch_gamma', type=float, default=1.3, help='Scheduler gamma')

    args = parser.parse_args()

    # Load the YAML run configuration (trusted local file) as a namespace.
    with open(os.path.join('configs', args.config), 'r') as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    # Derive the per-mode experiment directory name.
    args.exp = args.exp + _exp_suffix(args, new_config)

    # ---- Logging setup ----
    level = getattr(logging, args.verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError('level {} not supported'.format(args.verbose))

    handler1 = logging.StreamHandler()
    formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
    handler1.setFormatter(formatter)
    logger = logging.getLogger()
    logger.addHandler(handler1)
    logger.setLevel(level)

    # ---- Output directories ----
    os.makedirs('checkpoint', exist_ok=True)
    os.makedirs('precomputed', exist_ok=True)
    os.makedirs('runs', exist_ok=True)
    # BUGFIX: args.exp was created twice; once is enough.
    os.makedirs(args.exp, exist_ok=True)
    args.image_folder = os.path.join(args.exp, 'image_samples')
    if not os.path.exists(args.image_folder):
        os.makedirs(args.image_folder)
    else:
        # Folder exists: overwrite silently under --ni, otherwise ask.
        overwrite = False
        if args.ni:
            overwrite = True
        else:
            response = input("Image folder already exists. Overwrite? (Y/N)")
            if response.upper() == 'Y':
                overwrite = True

        if overwrite:
            os.makedirs(args.image_folder, exist_ok=True)
        else:
            print("Output image folder exists. Program halted.")
            sys.exit(0)

    # ---- Device selection & reproducibility ----
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    logging.info("Using device: {}".format(device))
    new_config.device = device

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Let cuDNN auto-tune kernels for fixed-size inputs.
    torch.backends.cudnn.benchmark = True

    return args, new_config
|
|
|
|
|
def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into argparse.Namespace.

    Nested dicts become nested Namespace attributes; every other value is
    attached to the namespace unchanged.
    """
    ns = argparse.Namespace()
    for key, value in config.items():
        attr = dict2namespace(value) if isinstance(value, dict) else value
        setattr(ns, key, attr)
    return ns
|
|
|
|
|
def main():
    """Entry point: parse configuration and run the selected mode.

    Returns:
        int: process exit status — 0 on success, 1 when the run raised or
        no recognized mode flag was supplied.
    """
    args, config = parse_args_and_config()
    print(">" * 80)
    logging.info("Exp instance id = {}".format(os.getpid()))
    logging.info("Exp comment = {}".format(args.comment))
    logging.info("Config =")
    print("<" * 80)

    runner = BoundaryDiffusion(args, config)
    # Mode flags checked in the original priority order; each flag name is
    # also the name of the runner method that implements it.
    modes = ('clip_finetune', 'radius', 'unconditional',
             'diffusion_hyperplane', 'boundary_search', 'edit_image_boundary')
    try:
        for flag in modes:
            if getattr(args, flag):
                getattr(runner, flag)()
                break
        else:
            print('Choose one mode!')
            raise ValueError
    except Exception:
        logging.error(traceback.format_exc())
        # BUGFIX: previously fell through to `return 0`, so a failed run
        # reported success to the shell / job scheduler.
        return 1

    return 0
|
|
|
|
|
# Script entry point: propagate main()'s return value as the process exit
# code so callers (shells, Slurm) can observe success or failure.
if __name__ == '__main__':
    sys.exit(main())
|
|