Spaces:
Running
Running
File size: 2,388 Bytes
ef187eb e7915f0 0cffd40 ef187eb 0cffd40 8b1e96d 0cffd40 8b1e96d 0cccf69 8b1e96d ec35e66 8b1e96d e7915f0 8b1e96d f286ae5 8b1e96d 3494613 6380dba 8b1e96d 3819ced 1e00cbb fee8445 0cffd40 ef187eb 8b1e96d 0cffd40 8b1e96d 0cffd40 556fb50 8b1e96d 3eaeeea 8b1e96d 0cffd40 8b1e96d e7915f0 8b1e96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import gradio as gr
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
import spaces
from PIL import Image
# Constants
# Hub repository of the SDXL base model the pipeline is loaded from.
base = "stabilityai/stable-diffusion-xl-base-1.0"
# Hub repository hosting the DMD2 UNet checkpoint files.
repo = "tianweiy/DMD2"
# Option name -> [checkpoint filename inside `repo`, number of inference steps].
checkpoints = {
    f"{steps}-Step": [f"dmd2_sdxl_{steps}step_unet_fp16.bin", steps]
    for steps in (1, 4)
}
# Step count of the checkpoint currently loaded into the UNet; None until the
# first load.
loaded = None
# Stylesheet handed to gr.Blocks; caps the width of the app container.
CSS = """
.gradio-container {
max-width: 690px !important;
}
"""
# Build the SDXL pipeline once, at module import time (not inside the
# GPU-decorated function). The UNet is created from the base model's config
# only, so no base UNet weights are downloaded: generate_image() loads a DMD2
# checkpoint into pipe.unet before the first generation because `loaded`
# starts out as None.
# NOTE(review): without CUDA, `pipe` is never defined and generate_image()
# would fail with a NameError.
if torch.cuda.is_available():
    unet_config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to("cuda", torch.float16)
    # Hand this UNet to the pipeline. Before, it was built and then discarded
    # while from_pretrained loaded a second, pretrained UNet whose weights
    # were overwritten by the DMD2 checkpoint anyway.
    pipe = DiffusionPipeline.from_pretrained(
        base, unet=unet, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")
# Function
@spaces.GPU()
def generate_image(prompt, ckpt):
    """Generate one image for *prompt* with the DMD2 UNet selected by *ckpt*.

    Args:
        prompt: Text prompt passed to the pipeline.
        ckpt: Key into ``checkpoints`` (``"1-Step"`` or ``"4-Step"``).

    Returns:
        The first image of the pipeline output (``results.images[0]``).

    Raises:
        gr.Error: If *ckpt* is not a key of ``checkpoints`` (e.g. a cleared
            dropdown), shown to the user instead of a bare KeyError.
    """
    global loaded
    print(prompt, ckpt)
    if ckpt not in checkpoints:
        raise gr.Error(f"Unknown checkpoint: {ckpt!r}")
    checkpoint, num_inference_steps = checkpoints[ckpt]
    # Reload only when the requested checkpoint differs from the one already
    # in pipe.unet; `loaded` is updated after a successful load.
    if loaded != num_inference_steps:
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
        # torch.load unpickles the downloaded file. weights_only=True limits
        # it to tensors and plain containers, so a tampered file cannot
        # execute code.
        state_dict = torch.load(
            hf_hub_download(repo, checkpoint), map_location="cuda", weights_only=True
        )
        pipe.unet.load_state_dict(state_dict)
        loaded = num_inference_steps
    call_kwargs = {"num_inference_steps": num_inference_steps, "guidance_scale": 0}
    if num_inference_steps == 1:
        # The 1-step checkpoint is sampled at the explicit timestep 399.
        call_kwargs["timesteps"] = [399]
    results = pipe(prompt, **call_kwargs)
    return results.images[0]
# Gradio Interface
with gr.Blocks(css=CSS) as demo:
    gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            # Offer exactly the keys generate_image() can look up, so the
            # choices cannot drift from the `checkpoints` table.
            ckpt = gr.Dropdown(
                label='Select inference steps',
                choices=list(checkpoints),
                value='4-Step',
                interactive=True,
            )
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='DMD2 Generated Image')
    # Pressing Enter in the textbox and clicking the button run the same handler.
    prompt.submit(
        fn=generate_image,
        inputs=[prompt, ckpt],
        outputs=img,
    )
    submit.click(
        fn=generate_image,
        inputs=[prompt, ckpt],
        outputs=img,
    )
demo.queue().launch()