Spaces:
Runtime error
Runtime error
File size: 2,763 Bytes
560a1b9 18d9bb8 560a1b9 cf2717f 560a1b9 cf2717f 560a1b9 cf2717f 10e5833 cf2717f 560a1b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
# import os
# CACHE_DIR = "/exp/domain-expansion/.cache"
# os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
# if not os.path.exists(CACHE_DIR):
# os.mkdir(CACHE_DIR)
import torch
import gradio as gr
from generate_aligned import generate_images
def main():
from huggingface_hub import hf_hub_download
dog_path = hf_hub_download("alvanlii/adobe-domain-expansion", "afhq50.pkl")
human_path = hf_hub_download("alvanlii/adobe-domain-expansion", "ffhq100.pkl")
def gen_img(fn_seed, fn_is_dog):
torch.manual_seed(fn_seed)
# print(fn_is_dog)
with torch.no_grad():
imgs = generate_images(dog_path if fn_is_dog else human_path, 2, 1)
return imgs
def load_examples():
torch.manual_seed(32)
# print(fn_is_dog)
with torch.no_grad():
imgs = generate_images(dog_path, 2, 1)
return 32, 1, imgs[0], imgs[1]
with gr.Blocks() as demo:
gr.HTML("""
<h1 style="font-weight: 900; margin-bottom: 7px;">
Domain Expansion of Image Generators (https://arxiv.org/abs/2301.05225)
</h1>
Yotam Nitzan, Michaël Gharbi, Richard Zhang, Taesung Park, Jun-Yan Zhu, Daniel Cohen-Or, Eli Shechtman <br/>
Using the pretrained weights for Humans and Dog faces to generate images in new domains. Only a quarter of the new domains are showcased due to large number of images generated
""")
with gr.Row():
seed = gr.Number(value=42, precision=1, label="Seed", interactive=True)
is_dog = gr.Radio(
["Humans", "Doggos"],
value="Doggos",
type="index",
show_label=False,
interactive=True
)
generate_button = gr.Button("Generate")
sample_button = gr.Button("Load Example")
with gr.Row():
g1 = gr.Gallery(
label="Generated images", show_label=False, elem_id="gallery1"
).style(grid=[10], height="auto")
with gr.Row():
g2 = gr.Gallery(
label="Generated images", show_label=False, elem_id="gallery2"
).style(grid=[10], height="auto")
generate_button.click(
fn=gen_img,
inputs=[
seed, is_dog
],
outputs=[g1, g2]
)
sample_button.click(
fn=load_examples,
inputs=[],
outputs=[seed, is_dog, g1, g2]
)
demo.queue(concurrency_count=1)
demo.launch(share=False, server_name="0.0.0.0")
if __name__ == "__main__":
    # Start the demo server only when run as a script, never on import.
    main()
|