File size: 4,950 Bytes
f12ab4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import os
import argparse
### Parameters
def str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter
    accepts the usual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as ``"true"``,
            ``"False"``, ``"1"`` or ``"no"``; case and surrounding whitespace
            are ignored.

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Create the argument parser for the pipeline.

    Every option is registered exactly once; registering the same option
    string twice makes ``argparse`` raise a conflicting-option error before
    any argument is parsed.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()

    # For all
    parser.add_argument('--mode', type=str, required=True, choices=['pdg', 'ft', 'both'],
                        help="pdg: Pose-aware dataset generation, ft: Fine-tuning 3D generative models, both: Doing both")
    parser.add_argument('--down_src_eg3d_from_nvidia', default=True, type=str2bool)

    # Pose-aware dataset generation
    parser.add_argument('--pdg_prompt', type=str, required=True)
    parser.add_argument('--pdg_generator_type', default='ffhq', type=str, choices=['ffhq', 'cat'])
    parser.add_argument('--pdg_strength', default=0.7, type=float)
    parser.add_argument('--pdg_guidance_scale', default=8, type=float)
    parser.add_argument('--pdg_num_images', default=1000, type=int)
    parser.add_argument('--pdg_sd_model_id', default='stabilityai/stable-diffusion-2-1-base', type=str)
    parser.add_argument('--pdg_num_inference_steps', default=50, type=int)
    parser.add_argument('--pdg_name_tag', default='', type=str)

    # Fine-tuning 3D generative models
    parser.add_argument('--ft_generator_type', default='same',
                        help="same: The same type as pdg_generator_type",
                        type=str, choices=['ffhq', 'cat', 'same'])
    parser.add_argument('--ft_kimg', default=200, type=int)
    parser.add_argument('--ft_batch', default=20, type=int)
    parser.add_argument('--ft_tick', default=1, type=int)
    parser.add_argument('--ft_snap', default=50, type=int)
    parser.add_argument('--ft_outdir', default='../training_runs', type=str)
    parser.add_argument('--ft_gpus', default=1, type=int)
    parser.add_argument('--ft_workers', default=8, type=int)
    parser.add_argument('--ft_data_max_size', default=500000000, type=int)
    parser.add_argument('--ft_freeze_dec_sr', default=True, type=str2bool)
    return parser


def generator_id_for(generator_type):
    """Return the pretrained EG3D checkpoint file name for a generator type.

    ``'cat'`` selects the AFHQ-cats checkpoint; every other value selects the
    FFHQ checkpoint.
    """
    if generator_type == 'cat':
        return 'afhqcats512-128.pkl'
    return 'ffhqrebalanced512-128.pkl'


def resolve_ft_generator_type(args):
    """Return the generator type to fine-tune, expanding ``'same'``.

    ``'same'`` means "use ``args.pdg_generator_type``"; any other value is
    returned unchanged. ``args`` is not modified.
    """
    if args.ft_generator_type == 'same':
        return args.pdg_generator_type
    return args.ft_generator_type


def build_data_gen_command(args):
    """Build the shell command that runs ``datid3d_data_gen.py``.

    Args:
        args: Parsed arguments; only the ``pdg_*`` fields are read.

    Returns:
        A single-line command string terminated by a newline.
    """
    parts = [
        'python datid3d_data_gen.py',
        f'--prompt="{args.pdg_prompt}"',
        f'--data_type={args.pdg_generator_type}',
        f'--strength={args.pdg_strength}',
        f'--guidance_scale={args.pdg_guidance_scale}',
        f'--num_images={args.pdg_num_images}',
        f'--sd_model_id="{args.pdg_sd_model_id}"',
        f'--num_inference_steps={args.pdg_num_inference_steps}',
        f'--name_tag={args.pdg_name_tag}',
    ]
    return ' '.join(parts) + '\n'


def build_train_command(args):
    """Build the shell command that runs ``train.py`` for fine-tuning.

    The dataset path is derived from the generation settings: generator type,
    prompt (spaces replaced by underscores) and name tag.

    Args:
        args: Parsed arguments; reads the ``ft_*`` fields plus
            ``pdg_generator_type``, ``pdg_prompt`` and ``pdg_name_tag``.

    Returns:
        A single-line command string terminated by a newline.
    """
    generator_type = resolve_ft_generator_type(args)
    generator_path = f'pretrained/{generator_id_for(generator_type)}'
    prompt_slug = args.pdg_prompt.replace(" ", "_")
    dataset_id = f'data_{args.pdg_generator_type}_{prompt_slug}{args.pdg_name_tag}'
    dataset_path = f'./exp_data/{dataset_id}/{dataset_id}.zip'
    parts = [
        'python train.py',
        f'--outdir={args.ft_outdir}',
        f'--cfg={generator_type}',
        f'--data="{dataset_path}"',
        f'--resume={generator_path} --freeze_dec_sr={args.ft_freeze_dec_sr}',
        f'--batch={args.ft_batch} --workers={args.ft_workers} --gpus={args.ft_gpus}',
        f'--tick={args.ft_tick} --snap={args.ft_snap} --data_max_size={args.ft_data_max_size} --kimg={args.ft_kimg}',
        '--gamma=5 --aug=ada --neural_rendering_resolution_final=128 --gen_pose_cond=True --gpc_reg_prob=0.8 --metrics=None',
    ]
    return ' '.join(parts) + '\n'


def _run_command(command):
    """Echo ``command``, run it through the shell, and stop on failure.

    Raises:
        SystemExit: If ``os.system`` reports a non-zero status, so a later
            stage never starts on top of a failed one.
    """
    print(f"{command} \n")
    status = os.system(command)
    if status != 0:
        raise SystemExit(f"Command failed with status {status}: {command.strip()}")


def _ensure_pretrained_generator(generator_id, from_nvidia):
    """Download ``pretrained/<generator_id>`` when it is not already present.

    Paths are relative to the current working directory.

    Args:
        generator_id: Checkpoint file name.
        from_nvidia: ``True`` to fetch from the NVIDIA NGC URL, ``False`` to
            fetch from the Hugging Face mirror.

    Raises:
        SystemExit: If the download command reports a non-zero status.
    """
    generator_path = f'pretrained/{generator_id}'
    if os.path.exists(generator_path):
        return
    os.makedirs('pretrained', exist_ok=True)
    print("Pretrained EG3D model cannot be found. Downloading the pretrained EG3D models.")
    if from_nvidia:
        download = f'wget -c https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/{generator_id} -O {generator_path}'
    else:
        download = f'wget https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/nvidia_{generator_id} -O {generator_path}'
    status = os.system(download)
    if status != 0:
        raise SystemExit(f"Downloading {generator_id} failed with status {status}")


def _generate_dataset(args):
    """Run pose-aware dataset generation; expects the cwd to be ``eg3d``."""
    _ensure_pretrained_generator(generator_id_for(args.pdg_generator_type),
                                 args.down_src_eg3d_from_nvidia)
    _run_command(build_data_gen_command(args))


def _fine_tune(args):
    """Run 3D generator fine-tuning; expects the cwd to be ``eg3d``."""
    _ensure_pretrained_generator(generator_id_for(resolve_ft_generator_type(args)),
                                 args.down_src_eg3d_from_nvidia)
    _run_command(build_train_command(args))


def _run_in_eg3d(stage, args):
    """Call ``stage(args)`` with ``eg3d`` as the working directory.

    The previous working directory is restored even when the stage raises.
    """
    previous_dir = os.getcwd()
    os.chdir('eg3d')
    try:
        stage(args)
    finally:
        os.chdir(previous_dir)


def main(argv=None):
    """Parse the command line and run the requested pipeline stages.

    Args:
        argv: Argument list to parse; ``None`` lets ``argparse`` read
            ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    ### Pose-aware target generation
    if args.mode in ('pdg', 'both'):
        _run_in_eg3d(_generate_dataset, args)

    ### Filtering process
    # TODO

    ### Fine-tuning 3D generative models
    if args.mode in ('ft', 'both'):
        _run_in_eg3d(_fine_tune, args)


if __name__ == '__main__':
    main()
|