# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
# import os
# CACHE_DIR = "/exp/domain-expansion/.cache"
# os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
# if not os.path.exists(CACHE_DIR):
#     os.mkdir(CACHE_DIR)
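# Gradio demo for "Domain Expansion of Image Generators": downloads the expanded
# FFHQ (human faces) and AFHQ (dog faces) checkpoints from the Hugging Face Hub
# and serves a small UI for sampling images from the expanded domains.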

import torch

import gradio as gr
from generate_aligned import generate_images

def main():
    from huggingface_hub import hf_hub_download
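    # Download the expanded-generator checkpoints (dogs: afhq50.pkl, humans: ffhq100.pkl).
    # hf_hub_download caches them locally, so later launches reuse the downloaded files.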
    dog_path = hf_hub_download("alvanlii/adobe-domain-expansion", "afhq50.pkl")
    human_path = hf_hub_download("alvanlii/adobe-domain-expansion", "ffhq100.pkl")

    def gen_img(fn_seed, fn_is_dog):
        # fn_is_dog is the Radio index: 0 = "Humans", 1 = "Doggos".
        torch.manual_seed(int(fn_seed))
        with torch.no_grad():
            imgs = generate_images(dog_path if fn_is_dog else human_path, 2, 1)
        return imgs

    def load_examples():
        # Fixed example: seed 32 on the dog checkpoint.
        torch.manual_seed(32)
        with torch.no_grad():
            imgs = generate_images(dog_path, 2, 1)
        # Also update the seed box and radio so the controls match the example;
        # the Radio expects the choice label ("Doggos") when set as an output.
        return 32, "Doggos", imgs[0], imgs[1]
    
    with gr.Blocks() as demo:
        gr.HTML("""
            <h1 style="font-weight: 900; margin-bottom: 7px;">
            Domain Expansion of Image Generators (https://arxiv.org/abs/2301.05225) 
            </h1> 
            Yotam Nitzan, Michaël Gharbi, Richard Zhang, Taesung Park, Jun-Yan Zhu, Daniel Cohen-Or, Eli Shechtman  <br/>
            Using the pretrained weights for human and dog faces to generate images in new domains. Only a quarter of the new domains are showcased due to the large number of images generated.
        """)
        with gr.Row():
            # Integer seed for reproducible sampling.
            seed = gr.Number(value=42, precision=0, label="Seed", interactive=True)
            # type="index" passes 0 for "Humans" and 1 for "Doggos" to gen_img.
            is_dog = gr.Radio(
                ["Humans", "Doggos"],
                value="Doggos",
                type="index",
                show_label=False,
                interactive=True
            )
            generate_button = gr.Button("Generate")
            sample_button = gr.Button("Load Example")

        with gr.Row():
            g1 = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery1"
            ).style(grid=[10], height="auto")
        with gr.Row():
            g2 = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery2"
            ).style(grid=[10], height="auto")
    
        # Generate the two galleries for the chosen seed and domain.
        generate_button.click(
            fn=gen_img,
            inputs=[
                seed, is_dog
            ],
            outputs=[g1, g2]
        )

        # Fill the controls and galleries with a fixed example.
        sample_button.click(
            fn=load_examples,
            inputs=[],
            outputs=[seed, is_dog, g1, g2]
        )

    # Process one request at a time; serve on all interfaces without a public share link.
    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")


if __name__ == "__main__":
    main()