|
import argparse
import os
import shutil
import subprocess
from datetime import datetime
from functools import partial
from glob import glob

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.io import read_image
from torchvision.utils import make_grid, save_image
|
|
|
|
|
# Make every figure saved through matplotlib use the 'tight' bounding box.
plt.rcParams.update({"savefig.bbox": 'tight'})
|
|
|
def show(imgs):
    """Draw one image, or a list of images, side by side in a single row.

    Args:
        imgs: A single image or a list of images. Each one is detached and
            converted with ``torchvision.transforms.functional.to_pil_image``
            before being drawn.

    The figure is created through ``plt.subplots`` and is not returned; axis
    ticks and tick labels are removed from every panel.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps the axes array 2-D even for a single image.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for axis, image in zip(axes[0], images):
        pil_image = F.to_pil_image(image.detach())
        axis.imshow(np.asarray(pil_image))
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
|
|
|
class Intermediate:
    """Mutable holder for the most recently processed input image.

    Attributes are plain and public:
      * ``input_img`` starts as ``None``.
      * ``input_img_time`` starts as ``0``.
    """

    def __init__(self):
        # Both fields start "empty"; they are overwritten by whoever owns the
        # instance once an image has been processed.
        self.input_img, self.input_img_time = None, 0
|
|
|
|
|
# Checkpoint file name for every fine-tuned model, keyed by model name.
# File names follow "<base generator>-<model name>.pkl"; insertion order is
# all "ffhq" models first, then the "cat" models.
model_ckpts = {
    style: f"{base}-{style}.pkl"
    for base, styles in (
        ("ffhq", (
            "elf", "greek_statue", "hobbit", "lego", "masquerade",
            "neanderthal", "orc", "pixar", "skeleton", "stone_golem",
            "super_mario", "tekken", "yoda", "zombie",
        )),
        ("cat", (
            "cat_in_Zootopia", "fox_in_Zootopia", "golden_aluminum_animal",
        )),
    )
    for style in styles
}
|
|
|
# Checkpoint file names of the models offered for image manipulation, keyed
# by model name. All of them are "ffhq" checkpoints. The insertion order is
# significant: code that iterates this mapping processes models in this order.
manip_model_ckpts = {
    style: f"ffhq-{style}.pkl"
    for style in (
        "super_mario", "lego", "neanderthal", "orc", "pixar", "skeleton",
        "stone_golem", "tekken", "greek_statue", "yoda", "zombie", "elf",
    )
}
|
|
|
|
|
def _manip_fetch_checkpoint(ckpt_name):
    """Download a fine-tuned checkpoint into ``finetuned/`` if it is missing.

    Returns:
        The local path of the checkpoint.

    Raises:
        gr.Error: if ``wget`` exits with a non-zero status.
    """
    ckpt_path = f'finetuned/{ckpt_name}'
    if not os.path.exists(ckpt_path):
        url = ('https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/'
               f'resolve/main/finetuned_models/{ckpt_name}')
        download = subprocess.run(['wget', url, '-O', ckpt_path])
        if download.returncode != 0:
            # wget -O creates the target file even when the transfer fails;
            # remove it so the next request retries the download.
            if os.path.exists(ckpt_path):
                os.remove(ckpt_path)
            raise gr.Error(f'Could not download checkpoint {ckpt_name}.')
    return ckpt_path


def _manip_first_match(pattern):
    """Return the alphabetically first path matching the glob *pattern*.

    Raises:
        gr.Error: if nothing matches (e.g. the external run wrote no output).
    """
    matches = sorted(glob(pattern))
    if not matches:
        raise gr.Error(f'No output found for {pattern}; check the server log.')
    return matches[0]


def _manip_concat_videos(video_paths, output_path):
    """Join *video_paths* back to back into *output_path* using ffmpeg."""
    cmd = ['ffmpeg']
    for path in video_paths:
        cmd += ['-i', path]
    count = len(video_paths)
    # Inputs of a complex filtergraph are addressed as "<input index>:v", and
    # a labelled filter output has to be selected with -map.
    input_pads = ''.join(f'[{idx}:v]' for idx in range(count))
    cmd += [
        '-filter_complex', f'{input_pads}concat=n={count}:v=1:a=0[output]',
        '-map', '[output]',
        '-vcodec', 'libx264',
        output_path,
    ]
    print(' '.join(cmd))
    if subprocess.run(cmd).returncode != 0:
        raise gr.Error('ffmpeg failed to concatenate the result videos.')


def TextGuidedImageTo3D(intermediate, img, model_name, num_inversion_steps, truncation):
    """Run text-guided manipulation of an uploaded image and collect the results.

    When *img* differs from ``intermediate.input_img`` it is written to
    ``input_imgs_gradio/input.png`` and a new timestamp is stored in
    ``intermediate.input_img_time``; that timestamp names the run directory
    under ``test_runs/``. ``datid3d_test.py`` is then invoked for every
    requested model that has no result video yet, in ``manip_from_inv`` mode
    when the run directory already holds ``3_inversion_result/*.pt`` files and
    in ``manip`` mode otherwise.

    Args:
        intermediate: Object with ``input_img`` and ``input_img_time``
            attributes, used to remember the last processed upload.
        img: The uploaded image; must support ``save(path)`` and comparison
            with the previously stored image.
        model_name: A key of ``model_ckpts``, or ``'all'`` to run every model
            in ``manip_model_ckpts``.
        num_inversion_steps: Value passed as ``--num_inv_steps`` (cast to int).
        truncation: Value passed as ``--trunc``.

    Returns:
        Tuple ``(aligned_img, result_img, result_video_pth)``: the first image
        found in ``2_pose_result``, the manipulated image (a one-column grid
        of all models for ``'all'``), and the path of the result video.

    Raises:
        gr.Error: if no image was uploaded, a checkpoint download fails, an
            expected output file is missing, or ffmpeg fails.
        KeyError: if *model_name* is neither ``'all'`` nor a known model.
    """
    if img is None:
        raise gr.Error('Please upload an input image first.')

    if img != intermediate.input_img:
        # New upload: start from a clean input directory and a new run tag.
        if os.path.exists('input_imgs_gradio'):
            shutil.rmtree('input_imgs_gradio')
        os.makedirs('input_imgs_gradio', exist_ok=True)
        img.save('input_imgs_gradio/input.png')
        intermediate.input_img = img
        intermediate.input_img_time = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')

    run_tag = intermediate.input_img_time
    run_dir = f'test_runs/manip_3D_recon_gradio_{run_tag}'
    manip_dir = f'{run_dir}/4_manip_result'
    generator_type = 'ffhq'

    run_all = model_name == 'all'
    requested = list(manip_model_ckpts) if run_all else [model_name]

    # Only models without a rendered video for this upload need to be run.
    pending = [
        name for name in requested
        if not os.path.exists(f'{manip_dir}/finetuned___{model_ckpts[name]}__input_inv.mp4')
    ]

    if pending:
        networks = [_manip_fetch_checkpoint(model_ckpts[name]) for name in pending]
        # Reuse the inversion of an earlier run on the same upload if present.
        has_inversion = bool(glob(f'{run_dir}/3_inversion_result/*.pt'))
        mode = 'manip_from_inv' if has_inversion else 'manip'
        # NOTE(review): the 'all' path has always passed "False" and the
        # single-model path "0" for this flag; both are kept as they were.
        # Confirm how datid3d_test.py parses the value and unify.
        down_src_flag = 'False' if run_all else '0'
        cmd = [
            'python', 'datid3d_test.py',
            '--mode', mode,
            '--indir=input_imgs_gradio',
            f'--generator_type={generator_type}',
            '--outdir=test_runs',
            f'--trunc={truncation}',
            '--network', *networks,
            f'--num_inv_steps={int(num_inversion_steps)}',
            f'--down_src_eg3d_from_nvidia={down_src_flag}',
            f'--name_tag=_gradio_{run_tag}',
            '--shape=False',
            '--w_frames', '60',
        ]
        print(' '.join(cmd))
        # An argument list (no shell) keeps request values from being
        # interpreted as shell syntax and needs no line continuations.
        subprocess.run(cmd)

    aligned_img = Image.open(_manip_first_match(f'{run_dir}/2_pose_result/*.png'))

    if not run_all:
        ckpt_name = model_ckpts[model_name]
        result_img = Image.open(_manip_first_match(f'{manip_dir}/*{ckpt_name}*.png'))
        result_video_pth = _manip_first_match(f'{manip_dir}/*{ckpt_name}*.mp4')
        return aligned_img, result_img, result_video_pth

    # 'all': stack every model's image in one column and join the videos in
    # the same order.
    stems = [f'{manip_dir}/finetuned___{model_ckpts[name]}__input_inv' for name in requested]
    result_grid_pt = make_grid([read_image(f'{stem}.png') for stem in stems], nrow=1)
    result_img = F.to_pil_image(result_grid_pt)

    result_video_pth = f'{manip_dir}/finetuned___ffhq-all__input_inv.mp4'
    if os.path.exists(result_video_pth):
        # Remove a stale combined video so ffmpeg does not ask to overwrite.
        os.remove(result_video_pth)
    _manip_concat_videos([f'{stem}.mp4' for stem in stems], result_video_pth)

    return aligned_img, result_img, result_video_pth
|
|
|
|
|
def SampleImage(model_name, num_samples, truncation, seed):
    """Generate sample images with a fine-tuned model and return them as a grid.

    Draws *num_samples* seeds from ``np.random.RandomState(seed)``, downloads
    the model checkpoint into ``finetuned/`` if it is missing, runs
    ``datid3d_test.py --mode image`` and stacks every PNG under
    ``test_runs/image/`` whose name contains the checkpoint name into a
    one-column grid.

    Args:
        model_name: A key of ``model_ckpts``.
        num_samples: Number of seeds to draw (cast to int, must be >= 1).
        truncation: Value passed as ``--trunc``.
        seed: Seed for the random state that draws the per-image seeds.

    Returns:
        A PIL image containing the grid of results.

    Raises:
        gr.Error: if *num_samples* is below 1, the checkpoint download fails,
            or no result image is found.
        KeyError: if *model_name* is unknown.
    """
    num_samples = int(num_samples)
    if num_samples < 1:
        raise gr.Error('Number of samples must be at least 1.')

    rng = np.random.RandomState(int(seed))
    sample_seeds = rng.choice(np.arange(10000), num_samples).tolist()
    seeds = ','.join(str(value) for value in sample_seeds)

    if model_name in ["fox_in_Zootopia", "cat_in_Zootopia", "golden_aluminum_animal"]:
        generator_type = 'cat'
    else:
        generator_type = 'ffhq'

    ckpt_name = model_ckpts[model_name]
    ckpt_path = f'finetuned/{ckpt_name}'
    if not os.path.exists(ckpt_path):
        url = ('https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/'
               f'resolve/main/finetuned_models/{ckpt_name}')
        download = subprocess.run(['wget', url, '-O', ckpt_path])
        if download.returncode != 0:
            # wget -O creates the target file even when the transfer fails;
            # remove it so the next request retries the download.
            if os.path.exists(ckpt_path):
                os.remove(ckpt_path)
            raise gr.Error(f'Could not download checkpoint {ckpt_name}.')

    cmd = [
        'python', 'datid3d_test.py',
        '--mode', 'image',
        f'--generator_type={generator_type}',
        '--outdir=test_runs',
        f'--seeds={seeds}',
        f'--trunc={truncation}',
        f'--network={ckpt_path}',
        '--shape=False',
    ]
    print(' '.join(cmd))
    # Argument list, no shell: request values cannot be read as shell syntax.
    subprocess.run(cmd)

    # NOTE(review): this pattern also matches images left by earlier runs of
    # the same checkpoint; confirm whether those should be part of the grid.
    result_img_pths = sorted(glob(f'test_runs/image/*{ckpt_name}*.png'))
    if not result_img_pths:
        raise gr.Error('No images were generated; check the server log.')

    result_grid_pt = make_grid([read_image(img_pth) for img_pth in result_img_pths], nrow=1)
    return F.to_pil_image(result_grid_pt)
|
|
|
|
|
|
|
|
|
def SampleVideo(model_name, grid_height, truncation, seed):
    """Generate a sample video grid with a fine-tuned model.

    Draws ``grid_height ** 2`` seeds from ``np.random.RandomState(seed)``,
    downloads the model checkpoint into ``finetuned/`` if it is missing, runs
    ``datid3d_test.py --mode video`` with a square ``--grid`` and returns the
    path of a matching video under ``test_runs/video/``.

    Args:
        model_name: A key of ``model_ckpts``.
        grid_height: Side length of the square grid (cast to int, must be >= 1).
        truncation: Value passed as ``--trunc``.
        seed: Seed for the random state that draws the per-cell seeds.

    Returns:
        Path of the alphabetically first ``.mp4`` in ``test_runs/video/`` whose
        name contains the checkpoint name.

    Raises:
        gr.Error: if *grid_height* is below 1, the checkpoint download fails,
            or no result video is found.
        KeyError: if *model_name* is unknown.
    """
    grid_height = int(grid_height)
    if grid_height < 1:
        raise gr.Error('Height of the grid must be at least 1.')

    rng = np.random.RandomState(int(seed))
    cell_seeds = rng.choice(np.arange(10000), grid_height**2).tolist()
    seeds = ','.join(str(value) for value in cell_seeds)

    if model_name in ["fox_in_Zootopia", "cat_in_Zootopia", "golden_aluminum_animal"]:
        generator_type = 'cat'
    else:
        generator_type = 'ffhq'

    ckpt_name = model_ckpts[model_name]
    ckpt_path = f'finetuned/{ckpt_name}'
    if not os.path.exists(ckpt_path):
        url = ('https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/'
               f'resolve/main/finetuned_models/{ckpt_name}')
        download = subprocess.run(['wget', url, '-O', ckpt_path])
        if download.returncode != 0:
            # wget -O creates the target file even when the transfer fails;
            # remove it so the next request retries the download.
            if os.path.exists(ckpt_path):
                os.remove(ckpt_path)
            raise gr.Error(f'Could not download checkpoint {ckpt_name}.')

    cmd = [
        'python', 'datid3d_test.py',
        '--mode', 'video',
        f'--generator_type={generator_type}',
        '--outdir=test_runs',
        f'--seeds={seeds}',
        f'--trunc={truncation}',
        f'--grid={grid_height}x{grid_height}',
        f'--network={ckpt_path}',
        '--shape=False',
    ]
    print(' '.join(cmd))
    # Argument list, no shell: request values cannot be read as shell syntax.
    subprocess.run(cmd)

    # NOTE(review): this pattern also matches videos left by earlier runs of
    # the same checkpoint and the alphabetically first match is returned;
    # confirm that it is always the video produced by this call.
    result_video_pths = sorted(glob(f'test_runs/video/*{ckpt_name}*.mp4'))
    if not result_video_pths:
        raise gr.Error('No video was generated; check the server log.')

    return result_video_pths[0]
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the three-tab Gradio UI, wire each "Generate"
    # button to its handler and launch the app (optionally with a public URL).
    parser = argparse.ArgumentParser()
    parser.add_argument('--share', action='store_true', help="public url")
    args = parser.parse_args()

    # Choices shared by the "Sample Images" and "Sample Videos" tabs.
    sample_model_names = [
        "elf", "greek_statue", "hobbit", "lego", "masquerade", "neanderthal", "orc", "pixar",
        "skeleton", "stone_golem", "super_mario", "tekken", "yoda", "zombie", "fox_in_Zootopia",
        "cat_in_Zootopia", "golden_aluminum_animal",
    ]

    demo = gr.Blocks(title="DATID-3D Interactive Demo")
    os.makedirs('finetuned', exist_ok=True)
    intermediate = Intermediate()
    # NOTE(review): the nesting of rows/columns below was reconstructed from
    # the order of the statements; confirm it matches the intended layout.
    with demo:
        gr.Markdown("# DATID-3D Interactive Demo")
        gr.Markdown(
            "### Demo of the CVPR 2023 paper \"DATID-3D: Diversity-Preserved Domain Adaptation Using Text-to-Image Diffusion for 3D Generative Model\"")

        with gr.Tab("Text-guided Manipulated 3D reconstruction"):
            gr.Markdown("Text-guided Image-to-3D Translation")
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    t_image_input = gr.Image(source='upload', type="pil", interactive=True)

                    t_model_name = gr.Radio(["super_mario", "lego", "neanderthal", "orc",
                                             "pixar", "skeleton", "stone_golem", "tekken",
                                             "greek_statue", "yoda", "zombie", "elf", "all"],
                                            label="Model fine-tuned through DATID-3D",
                                            value="super_mario", interactive=True)
                    with gr.Accordion("Advanced Options", open=False):
                        t_truncation = gr.Slider(label="Truncation psi", minimum=0, maximum=1.0, step=0.01, randomize=False, value=0.8)
                        t_num_inversion_steps = gr.Slider(200, 1000, value=200, step=1, label='Number of steps for the inversion')
                    with gr.Row():
                        t_button_gen_result = gr.Button("Generate Result", variant='primary')

                    with gr.Row():
                        t_align_image_result = gr.Image(label="Alignment result", interactive=False)
                with gr.Column(scale=1, variant='panel'):
                    with gr.Row():
                        t_video_result = gr.Video(label="Video result", interactive=False)

                    with gr.Row():
                        t_image_result = gr.Image(label="Image result", interactive=False)

        with gr.Tab("Sample Images"):
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    i_model_name = gr.Radio(
                        sample_model_names,
                        label="Model fine-tuned through DATID-3D",
                        value="super_mario", interactive=True)
                    # At least one sample: zero would yield an empty seed list.
                    i_num_samples = gr.Slider(1, 20, value=4, step=1, label='Number of samples')
                    i_seed = gr.Slider(label="Seed", minimum=0, maximum=1000000000, step=1, value=1235)
                    with gr.Accordion("Advanced Options", open=False):
                        i_truncation = gr.Slider(label="Truncation psi", minimum=0, maximum=1.0, step=0.01, randomize=False, value=0.8)
                    with gr.Row():
                        i_button_gen_image = gr.Button("Generate Image", variant='primary')
                with gr.Column(scale=1, variant='panel'):
                    with gr.Row():
                        i_image_result = gr.Image(label="Image result", interactive=False)

        with gr.Tab("Sample Videos"):
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    v_model_name = gr.Radio(
                        sample_model_names,
                        label="Model fine-tuned through DATID-3D",
                        value="super_mario", interactive=True)
                    # At least a 1x1 grid: zero would yield an empty seed list.
                    v_grid_height = gr.Slider(1, 5, value=2, step=1, label='Height of the grid')
                    v_seed = gr.Slider(label="Seed", minimum=0, maximum=1000000000, step=1, value=1235)
                    with gr.Accordion("Advanced Options", open=False):
                        v_truncation = gr.Slider(label="Truncation psi", minimum=0, maximum=1.0, step=0.01, randomize=False,
                                                 value=0.8)

                    with gr.Row():
                        v_button_gen_video = gr.Button("Generate Video", variant='primary')

                with gr.Column(scale=1, variant='panel'):

                    with gr.Row():
                        v_video_result = gr.Video(label="Video result", interactive=False)

        # The shared Intermediate instance is bound as the first argument so
        # the handler can remember the last uploaded image between clicks.
        t_button_gen_result.click(fn=partial(TextGuidedImageTo3D, intermediate),
                                  inputs=[t_image_input, t_model_name, t_num_inversion_steps, t_truncation],
                                  outputs=[t_align_image_result, t_image_result, t_video_result])
        i_button_gen_image.click(fn=SampleImage,
                                 inputs=[i_model_name, i_num_samples, i_truncation, i_seed],
                                 outputs=[i_image_result])
        # The video tab reads its own model selector, not the image tab's.
        v_button_gen_video.click(fn=SampleVideo,
                                 inputs=[v_model_name, v_grid_height, v_truncation, v_seed],
                                 outputs=[v_video_result])

    # One request at a time.
    demo.queue(concurrency_count=1)
    demo.launch(share=args.share)
|
|
|
|