File size: 2,353 Bytes
8b1e96d
0cffd40
8b1e96d
0cffd40
 
 
 
 
8b1e96d
0cffd40
8b1e96d
 
f286ae5
 
8b1e96d
 
 
ec35e66
 
 
 
 
 
8b1e96d
37a0a26
8b1e96d
 
 
 
 
f286ae5
8b1e96d
 
 
 
 
 
 
 
6e61ba6
8b1e96d
 
0cffd40
8b1e96d
0cffd40
8b1e96d
0cffd40
 
 
8b1e96d
0cffd40
556fb50
8b1e96d
3eaeeea
8b1e96d
0cffd40
8b1e96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import gradio as gr    
import torch
import PIL

# Constants: the SDXL base model the distilled UNet plugs into, and the Hub
# repo that hosts the DMD2 UNet weight files.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "tianweiy/DMD2"

# UI option label -> [UNet weight file inside `repo`, number of inference steps].
checkpoints = dict()
checkpoints["1-Step"] = ["dmd2_sdxl_1step_unet.bin", 1]
checkpoints["4-Step"] = ["dmd2_sdxl_4step_unet.bin", 4]

# Mutable state rather than a constant: the step count whose weights were last
# loaded into the UNet (None until the first request).
loaded = None

# Stylesheet handed to gr.Blocks; limits the page container width.
CSS = (
    "\n"
    ".gradio-container {\n"
    "  max-width: 690px !important;\n"
    "}\n"
)

# Build the UNet from the SDXL base *config* only (no pretrained weights); the
# DMD2 weights are loaded into it on demand by generate_image. The same object
# is handed to the pipeline, so loading weights into `unet` updates the pipeline.
# This runs at import time, not inside the @spaces.GPU function.
# The former `if torch.cuda.is_available()` guard around the pipeline was
# dropped: the UNet below is already moved to CUDA unconditionally, and skipping
# the pipeline only left `pipe` undefined, so every request failed with NameError.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    base, unet=unet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")


# Inference entry point, decorated with @spaces.GPU() for GPU execution on Spaces.
@spaces.GPU()
def generate_image(prompt, ckpt):
    """Generate one image for *prompt* using the DMD2 UNet selected by *ckpt*.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        ckpt: Key into ``checkpoints`` (e.g. ``"1-Step"`` or ``"4-Step"``).

    Returns:
        The first image in the pipeline output.

    Raises:
        gr.Error: If *ckpt* has no entry in ``checkpoints``.
    """
    global loaded
    print(prompt, ckpt)

    if ckpt not in checkpoints:
        # Surface a readable message in the UI instead of a bare KeyError.
        raise gr.Error(
            f"Unsupported option {ckpt!r}; choose one of: {', '.join(checkpoints)}"
        )

    checkpoint, num_inference_steps = checkpoints[ckpt]

    # Reload weights and scheduler only when the selection changes. `unet` is the
    # object the pipeline was built with, so `pipe` picks up the new weights.
    if loaded != num_inference_steps:
        # weights_only=True: the file is a downloaded pickle; a state dict needs
        # nothing beyond tensors/containers, so refuse arbitrary objects.
        state_dict = torch.load(
            hf_hub_download(repo, checkpoint), map_location="cuda", weights_only=True
        )
        unet.load_state_dict(state_dict)
        # NOTE(review): 1-step uses prediction_type="sample", others "epsilon" —
        # presumably matching how each checkpoint was distilled; confirm against
        # the DMD2 reference inference code.
        pipe.scheduler = LCMScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if num_inference_steps == 1 else "epsilon",
        )
        # Record the selection only after the load succeeded.
        loaded = num_inference_steps

    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
    return results.images[0]



# Gradio Interface

with gr.Blocks(css=CSS) as demo:
    gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            # Offer exactly the options that have a checkpoint. The previous
            # hard-coded list also showed '2-Step' and '8-Step', which have no
            # entry in `checkpoints` and crashed generate_image.
            ckpt = gr.Dropdown(
                label='Select inference steps',
                choices=list(checkpoints),
                value='4-Step',
                interactive=True,
            )
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='DMD2 Generated Image')

    # Pressing Enter in the textbox and clicking the button run the same call.
    prompt.submit(fn=generate_image,
                  inputs=[prompt, ckpt],
                  outputs=img,
                  )
    submit.click(fn=generate_image,
                 inputs=[prompt, ckpt],
                 outputs=img,
                 )

demo.queue().launch()