import argparse
import os
import shlex
import subprocess


def _str2bool(value):
    """Convert a command-line token such as ``true``, ``False`` or ``1`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

# General options.
parser.add_argument('--mode', type=str, required=True, choices=['pdg', 'ft', 'both'],
                    help="pdg: Pose-aware dataset generation, ft: Fine-tuning 3D generative models, both: Doing both")
# Registered exactly once: argparse raises "conflicting option string" when the
# same option is added twice. The bool converter makes an explicit
# "True"/"False" on the command line yield a real boolean instead of a string.
parser.add_argument('--down_src_eg3d_from_nvidia', default=True, type=_str2bool,
                    help="true: download the pretrained EG3D model from NVIDIA, false: download it from Hugging Face")

# Pose-aware dataset generation (pdg) options.
parser.add_argument('--pdg_prompt', type=str, required=True)
parser.add_argument('--pdg_generator_type', default='ffhq', type=str, choices=['ffhq', 'cat'])
parser.add_argument('--pdg_strength', default=0.7, type=float)
parser.add_argument('--pdg_guidance_scale', default=8, type=float)
parser.add_argument('--pdg_num_images', default=1000, type=int)
parser.add_argument('--pdg_sd_model_id', default='stabilityai/stable-diffusion-2-1-base', type=str)
parser.add_argument('--pdg_num_inference_steps', default=50, type=int)
parser.add_argument('--pdg_name_tag', default='', type=str)

# Fine-tuning (ft) options.
parser.add_argument('--ft_generator_type', default='same', type=str, choices=['ffhq', 'cat', 'same'],
                    help="same: The same type as pdg_generator_type")
parser.add_argument('--ft_kimg', default=200, type=int)
parser.add_argument('--ft_batch', default=20, type=int)
parser.add_argument('--ft_tick', default=1, type=int)
parser.add_argument('--ft_snap', default=50, type=int)
parser.add_argument('--ft_outdir', default='../training_runs', type=str)
# The default is an integer, so parse user-supplied values as integers too.
parser.add_argument('--ft_gpus', default=1, type=int)
parser.add_argument('--ft_workers', default=8, type=int)
parser.add_argument('--ft_data_max_size', default=500000000, type=int)
parser.add_argument('--ft_freeze_dec_sr', default=True, type=_str2bool)

NVIDIA_CHECKPOINT_URL = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/'
MIRROR_CHECKPOINT_URL = ('https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/'
                         'resolve/main/finetuned_models/nvidia_')


def generator_checkpoint(generator_type):
    """Return the pretrained EG3D checkpoint file name for *generator_type*.

    ``'cat'`` maps to ``afhqcats512-128.pkl``; every other value maps to
    ``ffhqrebalanced512-128.pkl``.
    """
    if generator_type == 'cat':
        return 'afhqcats512-128.pkl'
    return 'ffhqrebalanced512-128.pkl'


def resolve_ft_generator_type(args):
    """Return the generator type to fine-tune, expanding ``'same'`` to the pdg type."""
    if args.ft_generator_type == 'same':
        return args.pdg_generator_type
    return args.ft_generator_type


def run_command(cmd, description):
    """Echo *cmd*, run it without a shell, and abort the script if it fails.

    Passing an argument list (rather than a shell string) means values that
    contain spaces, quotes or ``$`` reach the child process verbatim.

    Raises:
        SystemExit: if the program cannot be started or exits non-zero.
    """
    print(' '.join(shlex.quote(part) for part in cmd) + ' \n')
    try:
        status = subprocess.run(cmd).returncode
    except OSError as err:
        raise SystemExit(f"Could not start {description}: {err}") from err
    if status != 0:
        raise SystemExit(f"{description} failed with exit status {status}.")


def ensure_pretrained(generator_id, from_nvidia):
    """Download ``pretrained/<generator_id>`` if it is missing and return its path.

    The path is relative to the current working directory. The file is
    fetched with ``wget`` into a ``.part`` file and renamed only after wget
    succeeds, so an interrupted or failed download is never mistaken for a
    complete checkpoint on the next run.
    """
    checkpoint_path = f'pretrained/{generator_id}'
    if os.path.exists(checkpoint_path):
        return checkpoint_path

    os.makedirs('pretrained', exist_ok=True)
    print("Pretrained EG3D model cannot be found. Downloading the pretrained EG3D models.")
    partial_path = checkpoint_path + '.part'
    if from_nvidia:
        wget_cmd = ['wget', '-c', NVIDIA_CHECKPOINT_URL + generator_id, '-O', partial_path]
    else:
        wget_cmd = ['wget', MIRROR_CHECKPOINT_URL + generator_id, '-O', partial_path]
    run_command(wget_cmd, f'download of {generator_id}')
    os.replace(partial_path, checkpoint_path)
    return checkpoint_path


def build_pdg_command(args):
    """Return the argument list that launches ``datid3d_data_gen.py``."""
    return [
        'python', 'datid3d_data_gen.py',
        f'--prompt={args.pdg_prompt}',
        f'--data_type={args.pdg_generator_type}',
        f'--strength={args.pdg_strength}',
        f'--guidance_scale={args.pdg_guidance_scale}',
        f'--num_images={args.pdg_num_images}',
        f'--sd_model_id={args.pdg_sd_model_id}',
        f'--num_inference_steps={args.pdg_num_inference_steps}',
        f'--name_tag={args.pdg_name_tag}',
    ]


def build_ft_command(args):
    """Return the argument list that launches ``train.py`` for fine-tuning.

    The dataset path is derived from the pdg generator type, the prompt (with
    spaces replaced by underscores) and the name tag.
    """
    generator_type = resolve_ft_generator_type(args)
    resume_path = f'pretrained/{generator_checkpoint(generator_type)}'
    dataset_id = f'data_{args.pdg_generator_type}_{args.pdg_prompt.replace(" ", "_")}{args.pdg_name_tag}'
    dataset_path = f'./exp_data/{dataset_id}/{dataset_id}.zip'
    return [
        'python', 'train.py',
        f'--outdir={args.ft_outdir}',
        f'--cfg={generator_type}',
        f'--data={dataset_path}',
        f'--resume={resume_path}',
        f'--freeze_dec_sr={args.ft_freeze_dec_sr}',
        f'--batch={args.ft_batch}',
        f'--workers={args.ft_workers}',
        f'--gpus={args.ft_gpus}',
        f'--tick={args.ft_tick}',
        f'--snap={args.ft_snap}',
        f'--data_max_size={args.ft_data_max_size}',
        f'--kimg={args.ft_kimg}',
        '--gamma=5',
        '--aug=ada',
        '--neural_rendering_resolution_final=128',
        '--gen_pose_cond=True',
        '--gpc_reg_prob=0.8',
        '--metrics=None',
    ]


def main():
    """Parse the command line and run the requested stage(s) inside ``eg3d/``.

    ``--mode pdg`` runs dataset generation, ``--mode ft`` runs fine-tuning and
    ``--mode both`` runs them in that order. A failing stage aborts the script
    so fine-tuning never starts after a failed dataset generation.
    """
    args = parser.parse_args()
    # Only a real boolean True selects the NVIDIA host; any other value (for
    # example an unconverted string from the command line) selects the mirror.
    from_nvidia = args.down_src_eg3d_from_nvidia is True

    launch_dir = os.getcwd()
    # All relative paths used by the stages are resolved from inside eg3d/.
    os.chdir('eg3d')
    try:
        if args.mode in ('pdg', 'both'):
            ensure_pretrained(generator_checkpoint(args.pdg_generator_type), from_nvidia)
            run_command(build_pdg_command(args), 'pose-aware dataset generation')

        if args.mode in ('ft', 'both'):
            ensure_pretrained(generator_checkpoint(resolve_ft_generator_type(args)), from_nvidia)
            run_command(build_ft_command(args), 'fine-tuning')
    finally:
        os.chdir(launch_dir)


if __name__ == "__main__":
    main()