alvanli
Fix typo
10e5833
# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
# import os
# CACHE_DIR = "/exp/domain-expansion/.cache"
# os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
# if not os.path.exists(CACHE_DIR):
# os.mkdir(CACHE_DIR)
import torch
import gradio as gr
from generate_aligned import generate_images
def main():
    """Download the pretrained generators and launch the Gradio demo.

    Two checkpoints are fetched from the ``alvanlii/adobe-domain-expansion``
    Hugging Face repo: ``afhq50.pkl`` (used for "Doggos") and ``ffhq100.pkl``
    (used for "Humans"). The UI has a seed box, a Humans/Doggos selector, a
    "Generate" button and a "Load Example" button. Both buttons fill two image
    galleries. The app is served on 0.0.0.0 behind a single-worker queue.
    """
    from huggingface_hub import hf_hub_download

    repo_id = "alvanlii/adobe-domain-expansion"
    dog_path = hf_hub_download(repo_id, "afhq50.pkl")
    human_path = hf_hub_download(repo_id, "ffhq100.pkl")

    # Radio labels. With type="index" the callback receives the position
    # (0 = Humans, 1 = Doggos), but the component's displayed value is the label.
    domain_choices = ["Humans", "Doggos"]
    example_seed = 32
    example_domain = "Doggos"

    def gen_img(fn_seed, fn_is_dog):
        """Seed torch and generate images for the selected domain.

        Args:
            fn_seed: Seed taken from the "Seed" number box.
            fn_is_dog: Index of the selected radio choice. A truthy value (1)
                selects the dog checkpoint; otherwise the human one is used.

        Returns:
            The result of ``generate_images``. It is routed to the two
            galleries, so presumably a pair of image lists (TODO confirm in
            generate_aligned).
        """
        torch.manual_seed(fn_seed)
        with torch.no_grad():
            # NOTE(review): the meaning of the positional args (2, 1) is defined
            # by generate_aligned.generate_images - confirm there.
            imgs = generate_images(dog_path if fn_is_dog else human_path, 2, 1)
        return imgs

    def load_examples():
        """Regenerate the fixed example and reset the controls to match it.

        Returns:
            (seed, radio label, first gallery, second gallery), in the order
            of the ``outputs`` wired to the "Load Example" button.
        """
        imgs = gen_img(example_seed, domain_choices.index(example_domain))
        # A Radio used as an output must receive one of its choice labels.
        # type="index" only affects what callbacks receive, so returning the
        # index (1) here would set a value that matches no choice.
        return example_seed, example_domain, imgs[0], imgs[1]

    with gr.Blocks() as demo:
        gr.HTML("""
<h1 style="font-weight: 900; margin-bottom: 7px;">
Domain Expansion of Image Generators (https://arxiv.org/abs/2301.05225)
</h1>
Yotam Nitzan, Michaël Gharbi, Richard Zhang, Taesung Park, Jun-Yan Zhu, Daniel Cohen-Or, Eli Shechtman <br/>
Using the pretrained weights for Humans and Dog faces to generate images in new domains. Only a quarter of the new domains are showcased due to large number of images generated
""")
        with gr.Row():
            # precision=0 yields an integer seed. With precision=1, fractional
            # input was silently truncated by torch.manual_seed.
            seed = gr.Number(value=42, precision=0, label="Seed", interactive=True)
            is_dog = gr.Radio(
                domain_choices,
                value="Doggos",
                type="index",
                show_label=False,
                interactive=True,
            )
            generate_button = gr.Button("Generate")
            sample_button = gr.Button("Load Example")
        with gr.Row():
            g1 = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery1"
            ).style(grid=[10], height="auto")
        with gr.Row():
            g2 = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery2"
            ).style(grid=[10], height="auto")

        generate_button.click(
            fn=gen_img,
            inputs=[seed, is_dog],
            outputs=[g1, g2],
        )
        sample_button.click(
            fn=load_examples,
            inputs=[],
            outputs=[seed, is_dog, g1, g2],
        )

    # One request at a time: each generation runs a full model pass.
    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")
if __name__ == "__main__":
    # Launch the demo only when this file is run as a script, not when imported.
    main()