6Morpheus6's picture
UI improvement
2290da3 verified
raw
history blame
7.07 kB
import os
import gradio as gr
import torch
import devicetorch
import tempfile
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LCMScheduler
from diffusers.schedulers import TCDScheduler
#import spaces
from PIL import Image
# Maps each dropdown choice to a three-item list:
#   [0] LoRA weight filename template ("{}" is filled with the pipeline mode),
#   [1] default number of inference steps shown on the slider,
#   [2] guidance scale passed to the pipeline.
checkpoints = {
    "2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0],
    "4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0],
    "8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0],
    "16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0],
    "Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5],
    "Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5],
    "Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5],
    "LCM-Like LoRA": ["pcm_{}_lcmlike_lora_converted.safetensors", 4, 0.0],
}
# Module-level placeholder; generate_image declares it `global` but nothing
# visible in this file assigns it after this line.
loaded = None
# Device handle obtained once at import time from devicetorch; the diffusers
# pipelines built in generate_image are moved to it via `.to(device)`.
device = devicetorch.get(torch)
#if torch.cuda.is_available():
# pipe_sdxl = StableDiffusionXLPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-base-1.0",
# torch_dtype=torch.float16,
# variant="fp16",
# ).to("cuda")
# pipe_sd15 = StableDiffusionPipeline.from_pretrained(
# "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
# ).to("cuda")
#@spaces.GPU(enable_queue=True)
def generate_image(
    prompt,
    ckpt,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
    mode="sdxl",
):
    """Generate one image with a PCM LoRA and save it as a downloadable PNG.

    Args:
        prompt: Text prompt forwarded to the diffusion pipeline.
        ckpt: Key into the module-level ``checkpoints`` table; selects the
            LoRA weight file and the guidance scale.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        progress: Gradio progress tracker (created with ``track_tqdm=True``).
            It is not referenced in the body; it is part of the signature so
            Gradio can inject it.
        mode: ``"sdxl"`` selects the Stable Diffusion XL base pipeline; any
            other value selects Stable Diffusion 1.5. The value is also used
            as the weight-file infix and the repo subfolder.

    Returns:
        A ``(image, path)`` tuple: the first generated image and the path of
        the PNG file it was saved to.

    Raises:
        KeyError: If ``ckpt`` is not a key of ``checkpoints``.
    """
    weight_template, _default_steps, guidance_scale = checkpoints[ckpt]
    checkpoint = weight_template.format(mode)

    # Build only the pipeline that was requested. Instantiating both base
    # models on every call doubles load time and peak memory for nothing.
    if mode == "sdxl":
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
        ).to(device)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
        ).to(device)

    pipe.load_lora_weights(
        "wangfuyun/PCM_Weights", weight_name=checkpoint, subfolder=mode
    )

    if ckpt == "LCM-Like LoRA":
        pipe.scheduler = LCMScheduler()
    else:
        pipe.scheduler = TCDScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            timestep_spacing="trailing",
        )

    results = pipe(
        prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
    )
    image = results.images[0]

    gradio_temp_dir = os.environ.get('GRADIO_TEMP_DIR', tempfile.gettempdir())
    # GRADIO_TEMP_DIR may point at a directory that has not been created yet.
    os.makedirs(gradio_temp_dir, exist_ok=True)
    # Write into a fresh per-call subdirectory so that overlapping or
    # consecutive requests never overwrite each other's download file, while
    # the file itself keeps the "image.png" name.
    output_dir = tempfile.mkdtemp(dir=gradio_temp_dir)
    temp_file_path = os.path.join(output_dir, "image.png")
    image.save(temp_file_path, format="PNG")

    # if SAFETY_CHECKER:
    #     images, has_nsfw_concepts = check_nsfw_images(results.images)
    #     if any(has_nsfw_concepts):
    #         gr.Warning("NSFW content detected.")
    #         return Image.new("RGB", (512, 512))
    #     return images[0]
    return image, temp_file_path
def clear_cache():
    """Call ``devicetorch.empty_cache`` with the torch module.

    Chained after each generation in the UI wiring; presumably this frees
    cached device memory (confirm against devicetorch). Returns ``None``.
    """
    devicetorch.empty_cache(torch)
def update_steps(ckpt):
    """Build the slider update for the newly selected checkpoint.

    The slider value is reset to the checkpoint's default step count, and the
    slider is made editable only for the "LCM-Like LoRA" choice.

    Args:
        ckpt: Key into the module-level ``checkpoints`` table.

    Returns:
        The ``gr.update(...)`` payload carrying ``interactive`` and ``value``.
    """
    default_steps = checkpoints[ckpt][1]
    user_adjustable = ckpt == "LCM-Like LoRA"
    return gr.update(interactive=user_adjustable, value=default_steps)
# Custom stylesheet handed to gr.Blocks(css=css): widens the app container,
# sets the height of images inside elements with class "img" (the output
# image), and fixes the height of the element with id "row-height" (the
# download file widget).
css = """
.gradio-container {
max-width: 95vw !important;
margin: auto !important
}
.img img {
height: 70vh !important
}
#row-height {
height: 65px !important
}
"""
# Gradio UI definition and event wiring for the PCM demo.
# NOTE(review): this file has lost its indentation; the nesting of the
# Blocks/Group/Row containers below must be restored before it can run.
with gr.Blocks(css=css) as demo:
# Page header with a short description and project links.
gr.Markdown(
"""
# Phased Consistency Model
Phased Consistency Model (PCM) is an image generation technique that addresses the limitations of the Latent Consistency Model (LCM) in high-resolution and text-conditioned image generation.
PCM outperforms LCM across various generation settings and achieves state-of-the-art results in both image and video generation.     [[paper](https://huggingface.co/papers/2405.18407)] [[arXiv](https://arxiv.org/abs/2405.18407)] [[code](https://github.com/G-U-N/Phased-Consistency-Model)] [[project page](https://g-u-n.github.io/projects/pcm)]
"""
)
# Input controls: prompt, checkpoint selector and step-count slider.
with gr.Group():
with gr.Row():
prompt = gr.Textbox(label="Prompt", scale=4)
# Choices are the keys of the module-level `checkpoints` table.
ckpt = gr.Dropdown(
label="Select inference steps",
choices=list(checkpoints.keys()),
value="2-Step",
)
# Starts read-only at 2 steps; update_steps toggles interactivity.
steps = gr.Slider(
label="Number of Inference Steps",
minimum=1,
maximum=20,
step=1,
value=2,
interactive=False,
)
# Keep the slider in sync with the selected checkpoint (runs unqueued).
ckpt.change(
fn=update_steps,
inputs=[ckpt],
outputs=[steps],
queue=False,
show_progress=False,
)
# One run button per base model.
with gr.Row():
submit_sdxl = gr.Button("Run on SDXL", scale=1)
submit_sd15 = gr.Button("Run on SD15", scale=1)
# Outputs: the generated image and a file widget for downloading it.
img = gr.Image(label="PCM Image", elem_classes="img")
download_image = gr.File(label="Download Image", file_count="single", interactive=False, elem_id="row-height")
# Example rows are [prompt, checkpoint key, step count].
# NOTE(review): the first prompt has a leading space and the last one starts
# with "tanding" (looks like a truncated "Standing") — confirm and fix.
gr.Examples(
examples=[
[" astronaut walking on the moon", "4-Step", 4],
[
"Photo of a dramatic cliffside lighthouse in a storm, waves crashing, symbol of guidance and resilience.",
"8-Step",
8,
],
[
"Vincent vangogh style, painting, a boy, clouds in the sky",
"Normal CFG 4-Step",
4,
],
[
"Echoes of a forgotten song drift across the moonlit sea, where a ghost ship sails, its spectral crew bound to an eternal quest for redemption.",
"4-Step",
4,
],
[
"Roger rabbit as a real person, photorealistic, cinematic.",
"16-Step",
16,
],
[
"tanding tall amidst the ruins, a stone golem awakens, vines and flowers sprouting from the crevices in its body.",
"LCM-Like LoRA",
4,
],
],
inputs=[prompt, ckpt, steps],
outputs=[img, download_image],
fn=generate_image,
#cache_examples="lazy",
)
# SDXL generation (generate_image's default mode), followed by a cache clear.
# NOTE(review): ckpt.change is also bound to update_steps above, so a dropdown
# change fires both listeners; confirm generation sees the updated step count
# and that generating on every dropdown change is intended.
gr.on(
fn=generate_image,
triggers=[ckpt.change, prompt.submit, submit_sdxl.click],
inputs=[prompt, ckpt, steps],
outputs=[img, download_image],
).then(
fn=clear_cache,
inputs=[],
outputs=None
)
# SD 1.5 generation: same inputs/outputs, with mode forced to "sd15".
gr.on(
fn=lambda *args: generate_image(*args, mode="sd15"),
triggers=[submit_sd15.click],
inputs=[prompt, ckpt, steps],
outputs=[img, download_image],
).then(
fn=clear_cache,
inputs=[],
outputs=None
)
# Enable the request queue and start the server with the API closed and hidden.
demo.queue(api_open=False).launch(show_api=False)