# Hugging Face file-viewer header (not part of the program):
# author: sincostanx | commit c790eb8 "fix minor bug" | file size: 9.35 kB
import gradio as gr
import torch
from pipeline import CustomPipeline, setup_scheduler
from PIL import Image
# from easydict import EasyDict as edict
# Shared state for the demo callbacks. prepare_model() assigns all three when
# a model is selected; the run_*_demo functions only read them.
#   original_pipe   - pipeline returned by CustomPipeline.from_pretrained
#   original_config - scheduler config taken from that pipeline
#   device          - torch.device the pipeline was moved to
original_pipe = original_config = device = None
# def run_dpm_demo(id, prompt, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
def run_dpm_demo(prompt, beta, num_inference_steps, guidance_scale, seed):
    """Sample the same prompt twice with the "DPM-Solver++" scheduler.

    The first pass configures the scheduler with ``beta=1.0`` (no momentum,
    i.e. plain DPM-Solver++); the second uses the caller's ``beta``. Both
    passes start from a generator seeded with the same ``seed`` so the two
    images differ only by the scheduler setting.

    Unlike the earlier revision, the images are no longer written to
    ``temp1.png`` / ``temp2.png``: that was debug output which concurrent
    requests would overwrite and which the sibling demo callbacks never did.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        beta: Scheduler ``beta`` used for the second (momentum) pass.
        num_inference_steps: Step count; coerced with ``int()`` because the
            UI widget may deliver a float.
        guidance_scale: Forwarded unchanged to the pipeline.
        seed: Generator seed; coerced with ``int()``.

    Returns:
        ``[image_without_momentum, image_with_momentum]``.
    """
    # Build a fresh pipeline from the shared components so that swapping the
    # scheduler below does not touch the globally cached pipeline object.
    pipe = CustomPipeline(**original_pipe.components)

    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    scheduler = "DPM-Solver++"

    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": "dpm",
    }

    images = []
    # beta=1.0 first (baseline), then the requested beta (momentum).
    for run_beta in (1.0, beta):
        pipe = setup_scheduler(pipe, scheduler, beta=run_beta, original_config=original_config)
        # Re-seed for every pass so both runs start from identical noise.
        params["generator"] = torch.Generator(device=device).manual_seed(seed)
        images.append(pipe(**params).images[0])

    return images
# def run_plms_demo(id, prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed, enable_token_merging):
def run_plms_demo(prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed):
    """Sample the same prompt twice with the "PLMS" scheduler.

    One image is produced with ``beta=1.0`` (no momentum, equivalent to plain
    PLMS) and one with the caller's ``beta``; both passes use a generator
    seeded with the same ``seed``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        order: Forwarded to ``setup_scheduler`` as ``order``.
        beta: Scheduler ``beta`` for the momentum pass.
        momentum_type: ``"Polyak's heavy ball"`` selects pipeline method
            ``"hb"``; any other value selects ``"nt"``.
        num_inference_steps: Step count, coerced with ``int()``.
        guidance_scale: Forwarded unchanged to the pipeline.
        seed: Generator seed, coerced with ``int()``.

    Returns:
        ``[image_without_momentum, image_with_momentum]``.
    """
    pipe = CustomPipeline(**original_pipe.components)

    seed_value = int(seed)
    step_count = int(num_inference_steps)

    if momentum_type == "Polyak's heavy ball":
        method = "hb"
    else:
        method = "nt"

    results = []
    # First pass: beta fixed to 1.0 (baseline). Second pass: requested beta.
    for pass_beta in (1.0, beta):
        pipe = setup_scheduler(
            pipe,
            "PLMS",
            momentum_type=momentum_type,
            order=order,
            beta=pass_beta,
            original_config=original_config,
        )
        # Fresh generator per pass so both runs share the same starting seed.
        generator = torch.Generator(device=device).manual_seed(seed_value)
        output = pipe(
            prompt=prompt,
            num_inference_steps=step_count,
            guidance_scale=guidance_scale,
            method=method,
            generator=generator,
        )
        results.append(output.images[0])

    return results
# def run_ghvb_demo(id, prompt, order, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
def run_ghvb_demo(prompt, order, beta, num_inference_steps, guidance_scale, seed):
    """Sample the same prompt twice with the "GHVB" scheduler.

    The baseline pass configures the scheduler with ``beta=1.0`` (no
    momentum, equivalent to PLMS); the second pass uses the caller's
    ``beta``. Each pass gets its own generator seeded with the same ``seed``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        order: Forwarded to ``setup_scheduler`` as ``order``.
        beta: Scheduler ``beta`` for the momentum pass.
        num_inference_steps: Step count, coerced with ``int()``.
        guidance_scale: Forwarded unchanged to the pipeline.
        seed: Generator seed, coerced with ``int()``.

    Returns:
        ``[baseline_image, momentum_image]``.
    """
    pipe = CustomPipeline(**original_pipe.components)

    seed_value = int(seed)
    # Keyword arguments common to both sampling passes.
    shared_kwargs = dict(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        method="ghvb",
    )

    # Baseline pass: beta fixed to 1.0.
    pipe = setup_scheduler(pipe, "GHVB", order=order, beta=1.0, original_config=original_config)
    baseline_generator = torch.Generator(device=device).manual_seed(seed_value)
    baseline_image = pipe(**shared_kwargs, generator=baseline_generator).images[0]

    # Momentum pass: requested beta, same seed.
    pipe = setup_scheduler(pipe, "GHVB", order=order, beta=beta, original_config=original_config)
    momentum_generator = torch.Generator(device=device).manual_seed(seed_value)
    momentum_image = pipe(**shared_kwargs, generator=momentum_generator).images[0]

    return [baseline_image, momentum_image]
if __name__ == "__main__":
    # Per-sampler registries of components, filled while the layout is built
    # and consumed when the click handlers are wired up at the bottom.
    inputs = {}
    outputs = {}
    buttons = {}
    # No preset choices; the dropdown accepts any typed model id instead.
    list_models = [
    ]

    with gr.Blocks() as demo:
        gr.Markdown(
            """
# Momentum-Diffusion Demo
A novel sampling method for diffusion models based on momentum to reduce artifacts
"""
        )

        # --- Model selection -------------------------------------------------
        model_id = gr.Dropdown(list_models, label="Model ID", value="Linaqruf/anything-v3.0", allow_custom_value=True)
        enable_token_merging = gr.Checkbox(label="Enable Token Merging", value=False)
        # output = gr.Textbox()
        buttons["select_model"] = gr.Button("Select")

        # Everything inside the tabs starts hidden and is revealed by
        # prepare_model() once a model has been loaded.
        # NOTE(review): the tabs themselves are also created with
        # visible=False but are never listed in the Select handler's outputs.
        # Confirm against the installed Gradio version that tab visibility is
        # not enforced, otherwise the tabs will stay hidden.

        # --- GHVB tab --------------------------------------------------------
        with gr.Tab("GHVB", visible=False) as tab3:
            prompt3 = gr.Textbox(label="Prompt", value="a cozy cafe", visible=False)
            with gr.Row(visible=False) as row31:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=12)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)
            with gr.Row(visible=False) as row32:
                out1 = gr.Image(label="PLMS", interactive=False)
                out2 = gr.Image(label="GHVB", interactive=False)
            # Order must match run_ghvb_demo's positional parameters.
            inputs["GHVB"] = [prompt3, order, beta, num_inference_steps, guidance_scale, seed]
            outputs["GHVB"] = [out1, out2]
            buttons["GHVB"] = gr.Button("Sample", visible=False)

        # --- PLMS tab --------------------------------------------------------
        with gr.Tab("PLMS", visible=False) as tab2:
            prompt2 = gr.Textbox(label="Prompt", value="1girl", visible=False)
            with gr.Row(visible=False) as row21:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, label="beta")
                momentum_type = gr.Dropdown(["Polyak's heavy ball", "Nesterov"], label="Momentum Type", value="Polyak's heavy ball")
                num_inference_steps = gr.Number(label="Number of steps", value=10)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)
            with gr.Row(visible=False) as row22:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)
            # Order must match run_plms_demo's positional parameters.
            inputs["PLMS"] = [prompt2, order, beta, momentum_type, num_inference_steps, guidance_scale, seed]
            outputs["PLMS"] = [out1, out2]
            buttons["PLMS"] = gr.Button("Sample", visible=False)

        # --- DPM-Solver++ tab ------------------------------------------------
        with gr.Tab("DPM-Solver++", visible=False) as tab1:
            prompt1 = gr.Textbox(label="Prompt", value="1girl", visible=False)
            with gr.Row(visible=False) as row11:
                beta = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=15)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=20)
                seed = gr.Number(label="Seed", value=0)
            with gr.Row(visible=False) as row12:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)
            # Order must match run_dpm_demo's positional parameters.
            inputs["DPM-Solver++"] = [prompt1, beta, num_inference_steps, guidance_scale, seed]
            outputs["DPM-Solver++"] = [out1, out2]
            buttons["DPM-Solver++"] = gr.Button("Sample", visible=False)

        # Components that prepare_model() reveals after a successful load.
        all_outputs = [
            row11, row12, row21, row22, row31, row32,
            prompt1, prompt2, prompt3,
            buttons["DPM-Solver++"], buttons["PLMS"], buttons["GHVB"],
        ]

        def prepare_model(model_id, enable_token_merging):
            """Load the selected model into the module-level state and reveal the UI.

            Args:
                model_id: Identifier passed to ``CustomPipeline.from_pretrained``.
                enable_token_merging: When truthy, patch the pipeline with
                    ``tomesd.apply_patch(..., ratio=0.5)``.

            Returns:
                A dict mapping every component in ``all_outputs`` to
                ``gr.update(visible=True)``.
            """
            global original_pipe, original_config, device

            # Release the previous pipeline before loading the next one.
            # Rebinding to None (instead of `del`) keeps the global name
            # defined, so a failed load does not turn the next "Select" click
            # into a NameError.
            original_pipe = None
            original_pipe = CustomPipeline.from_pretrained(model_id)

            if torch.cuda.is_available():
                # Keep the original preference for the second GPU, but fall
                # back to the first on single-GPU hosts where "cuda:1" does
                # not exist.
                gpu_index = 1 if torch.cuda.device_count() > 1 else 0
                device = torch.device(f"cuda:{gpu_index}")
            else:
                device = torch.device("cpu")
            original_pipe = original_pipe.to(device)

            if enable_token_merging:
                # Imported lazily so tomesd is only required when requested.
                import tomesd
                tomesd.apply_patch(original_pipe, ratio=0.5)
                print("Enabled Token merging.")

            # Captured so the sampling callbacks can pass it to setup_scheduler.
            original_config = original_pipe.scheduler.config
            print(type(original_pipe))
            print(original_config)

            return {component: gr.update(visible=True) for component in all_outputs}

        # --- Event wiring ----------------------------------------------------
        buttons["select_model"].click(prepare_model, inputs=[model_id, enable_token_merging], outputs=all_outputs)
        for sampler_name, callback in (
            ("DPM-Solver++", run_dpm_demo),
            ("PLMS", run_plms_demo),
            ("GHVB", run_ghvb_demo),
        ):
            buttons[sampler_name].click(callback, inputs=inputs[sampler_name], outputs=outputs[sampler_name])

    demo.launch()