Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import random | |
| from huggingface_hub import hf_hub_download | |
| import spaces # [uncomment to use ZeroGPU] | |
| from diffusers import FluxPipeline | |
| import torch | |
# Run on the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-schnell"  # Replace with the model you would like to use
torch_dtype = torch.bfloat16

# Only the pruned model is loaded: the original and the pruned FLUX models
# cannot both be hosted in this Space.
pruned_pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)

# The checkpoint is a whole pickled transformer module (it is assigned directly
# to `pruned_pipe.transformer`), not a state dict. PyTorch >= 2.6 defaults to
# `weights_only=True`, which refuses to unpickle arbitrary classes, so the flag
# must be disabled explicitly.
# SECURITY: unpickling can execute arbitrary code; only load checkpoints from a
# repository you trust.
pruned_pipe.transformer = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"),
    map_location="cpu",
    weights_only=False,
)
pruned_pipe = pruned_pipe.to(device)
# Largest value of a signed 32-bit integer, i.e. 2**31 - 1.
MAX_SEED = np.iinfo("int32").max
# Presumably the maximum image width/height in pixels — confirm before relying on it.
MAX_IMAGE_SIZE = 1024
@spaces.GPU  # Requests a ZeroGPU device for the duration of each call; a no-op on other hardware.
def generate_images(prompt, seed, steps):
    """Generate one image for *prompt* with the pruned FLUX pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the random generator. It is coerced to ``int`` because
            the Gradio ``Number`` input can deliver a float, or ``None`` when
            the box is cleared.
        steps: Number of denoising steps. It is coerced to ``int`` because the
            Gradio ``Slider`` can deliver a float.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no seed was provided, shown to the user in the UI.
    """
    if seed is None:
        raise gr.Error("Please enter a seed.")

    # Create the generator on the same device the pipeline was moved to; a
    # hard-coded "cuda" generator fails when `device` fell back to "cpu".
    generator = torch.Generator(device).manual_seed(int(seed))
    result = pruned_pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=int(steps),
    )
    return result.images[0]
# Sample text prompts for the demo, one string per example.
examples = [
    "A clock tower floating in a sea of clouds",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
    (
        "A sprawling cyberpunk metropolis at night, with towering skyscrapers "
        "emitting neon lights of every color, holographic billboards "
        "advertising alien languages"
    ),
]
# Stylesheet for an element with id "col-container": centred, at most 640px wide.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Markdown title and explanatory note for the page header.
header = """
# 🌱 EcoDiff Pruned FLUX-Schnell (20% Pruning Ratio)
We are not able to host two FLUX models in the same space, so we only show the pruned model here. **👉 [Click here to compare with the Original FLUX Model](https://huggingface.co/spaces/black-forest-labs/FLUX.1-schnell)**.
"""

# Row of shields.io badges linking to the paper and to the pruned model weights.
header_2 = """
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
</div>
"""
# Gradio UI: a row of inputs (prompt, seed, step count), a generate button,
# example prompts, and the pruned model's output image.
with gr.Blocks() as demo:
    gr.Markdown(header)
    gr.HTML(header_2)

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=5,
            step=1,
            scale=1,
        )

    generate_btn = gr.Button("Generate Images")

    # Reuse the module-level prompt list rather than duplicating it inline,
    # so the two copies cannot drift apart.
    gr.Examples(examples=examples, inputs=[prompt])

    with gr.Row():
        ecodiff_output = gr.Image(label="EcoDiff Output")

    # Generate on a button click or when Enter is pressed in the prompt box.
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_images,
        inputs=[prompt, seed, steps],
        outputs=[ecodiff_output],
    )

if __name__ == "__main__":
    demo.launch()