# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.

# Optional: redirect the Hugging Face datasets cache.
# import os
# CACHE_DIR = "/exp/domain-expansion/.cache"
# os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
# if not os.path.exists(CACHE_DIR):
#     os.mkdir(CACHE_DIR)

import torch
import gradio as gr

from generate_aligned import generate_images


def main():
    """Build and launch the Gradio demo for "Domain Expansion of Image Generators".

    Downloads two pretrained checkpoints from the Hugging Face Hub
    (``afhq50.pkl`` for dog faces, ``ffhq100.pkl`` for human faces) and serves
    a UI that generates two galleries of images for a chosen seed and model.
    """
    from huggingface_hub import hf_hub_download

    # NOTE(review): these are pickled checkpoints downloaded from the Hub. If
    # generate_images unpickles them, only ever point this at a trusted repo.
    dog_path = hf_hub_download("alvanlii/adobe-domain-expansion", "afhq50.pkl")
    human_path = hf_hub_download("alvanlii/adobe-domain-expansion", "ffhq100.pkl")

    # Choices shown in the model selector. The Radio uses type="index", so
    # callbacks receive the position in this list (0 = Humans, 1 = Doggos).
    model_choices = ["Humans", "Doggos"]
    dog_choice = "Doggos"
    dog_index = model_choices.index(dog_choice)
    example_seed = 32

    def gen_img(fn_seed, fn_is_dog):
        """Generate images with the dog model (truthy index) or the human model.

        Args:
            fn_seed: integer seed for torch's global RNG.
            fn_is_dog: index of the selected Radio choice; non-zero selects dogs.

        Returns:
            The result of ``generate_images``; the UI routes its two entries to
            the two galleries.
        """
        torch.manual_seed(int(fn_seed))
        with torch.no_grad():
            imgs = generate_images(dog_path if fn_is_dog else human_path, 2, 1)
        return imgs

    def load_examples():
        """Generate the fixed example (seed 32, dog model) and sync the inputs.

        The Radio output is given the choice *label*: ``type="index"`` only
        affects the value passed into callbacks, not values sent back to the UI.
        """
        imgs = gen_img(example_seed, dog_index)
        return example_seed, dog_choice, imgs[0], imgs[1]

    with gr.Blocks() as demo:
        gr.HTML("""

Domain Expansion of Image Generators (https://arxiv.org/abs/2301.05225)

Yotam Nitzan, Michaël Gharbi, Richard Zhang, Taesung Park, Jun-Yan Zhu, Daniel Cohen-Or, Eli Shechtman
Using the pretrained weights for Humans and Dog faces to generate images in new domains. Only a quarter of the new domains are showcased due to large number of images generated """)
        with gr.Row():
            # precision=0 makes Gradio return an int seed instead of a float.
            seed = gr.Number(value=42, precision=0, label="Seed", interactive=True)
            is_dog = gr.Radio(
                model_choices, value=dog_choice, type="index", show_label=False, interactive=True
            )
            generate_button = gr.Button("Generate")
            sample_button = gr.Button("Load Example")
        with gr.Row():
            g1 = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery1"
            ).style(grid=[10], height="auto")
        with gr.Row():
            g2 = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery2"
            ).style(grid=[10], height="auto")

        generate_button.click(fn=gen_img, inputs=[seed, is_dog], outputs=[g1, g2])
        sample_button.click(fn=load_examples, inputs=[], outputs=[seed, is_dog, g1, g2])

    # Process one request at a time.
    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")


if __name__ == "__main__":
    main()