Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from pipeline import CustomPipeline, setup_scheduler | |
from PIL import Image | |
# from easydict import EasyDict as edict | |
# Shared state filled in by prepare_model() when the user clicks "Select":
#   original_pipe   - the loaded base pipeline whose components each demo reuses
#   original_config - that pipeline's scheduler config, handed to setup_scheduler
#   device          - torch device the pipeline was moved to (used for generators)
original_pipe = original_config = device = None
def run_dpm_demo(prompt, beta, num_inference_steps, guidance_scale, seed):
    """Sample ``prompt`` with DPM-Solver++ twice: without and with momentum.

    Both runs reuse the same seed, so the only intended difference between the
    two images is the momentum coefficient (``beta=1.0`` is the no-momentum
    baseline, equivalent to plain DPM-Solver++).

    Args:
        prompt: Text prompt passed to the pipeline.
        beta: Momentum coefficient used for the second sample.
        num_inference_steps: Number of sampling steps (cast to ``int``).
        guidance_scale: Classifier-free guidance scale.
        seed: Random seed (cast to ``int``) used for both samples.

    Returns:
        ``[image_without_momentum, image_with_momentum]``.

    Raises:
        gr.Error: If no model has been loaded yet via the "Select" button.
    """
    if original_pipe is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please select a model first.")
    # New pipeline around the shared components, presumably so setup_scheduler
    # can swap the scheduler without altering original_pipe itself.
    pipe = CustomPipeline(**original_pipe.components)
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    scheduler = "DPM-Solver++"
    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": "dpm",
    }
    # without momentum (equivalent to DPM-Solver++)
    pipe = setup_scheduler(pipe, scheduler, beta=1.0, original_config=original_config)
    # A fresh, identically seeded generator per run keeps the two samples comparable.
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    ori_image = pipe(**params).images[0]
    # with momentum
    pipe = setup_scheduler(pipe, scheduler, beta=beta, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**params).images[0]
    return [ori_image, image]
def run_plms_demo(prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed):
    """Sample ``prompt`` with PLMS twice: without and with momentum.

    Both runs reuse the same seed; ``beta=1.0`` is the no-momentum baseline
    (equivalent to plain PLMS).

    Args:
        prompt: Text prompt passed to the pipeline.
        order: Multistep order from the UI slider (cast to ``int``).
        beta: Momentum coefficient used for the second sample.
        momentum_type: ``"Polyak's heavy ball"`` selects method ``"hb"``;
            any other value selects ``"nt"`` (Nesterov).
        num_inference_steps: Number of sampling steps (cast to ``int``).
        guidance_scale: Classifier-free guidance scale.
        seed: Random seed (cast to ``int``) used for both samples.

    Returns:
        ``[image_without_momentum, image_with_momentum]``.

    Raises:
        gr.Error: If no model has been loaded yet via the "Select" button.
    """
    if original_pipe is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please select a model first.")
    # New pipeline around the shared components, presumably so setup_scheduler
    # can swap the scheduler without altering original_pipe itself.
    pipe = CustomPipeline(**original_pipe.components)
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    # The slider uses step=1, so this only normalizes a possible float like 4.0.
    order = int(order)
    scheduler = "PLMS"
    method = "hb" if momentum_type == "Polyak's heavy ball" else "nt"
    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": method,
    }
    # without momentum (equivalent to PLMS)
    pipe = setup_scheduler(pipe, scheduler, momentum_type=momentum_type, order=order, beta=1.0, original_config=original_config)
    # A fresh, identically seeded generator per run keeps the two samples comparable.
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    ori_image = pipe(**params).images[0]
    # with momentum
    pipe = setup_scheduler(pipe, scheduler, momentum_type=momentum_type, order=order, beta=beta, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**params).images[0]
    return [ori_image, image]
def run_ghvb_demo(prompt, order, beta, num_inference_steps, guidance_scale, seed):
    """Sample ``prompt`` with the GHVB scheduler at ``beta=1.0`` and at ``beta``.

    Both runs reuse the same seed; the ``beta=1.0`` run is the baseline that the
    UI labels "PLMS" (equivalent to PLMS).

    Args:
        prompt: Text prompt passed to the pipeline.
        order: Multistep order from the UI slider (cast to ``int``).
        beta: Momentum coefficient used for the second sample.
        num_inference_steps: Number of sampling steps (cast to ``int``).
        guidance_scale: Classifier-free guidance scale.
        seed: Random seed (cast to ``int``) used for both samples.

    Returns:
        ``[baseline_image, ghvb_image]``.

    Raises:
        gr.Error: If no model has been loaded yet via the "Select" button.
    """
    if original_pipe is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please select a model first.")
    # New pipeline around the shared components, presumably so setup_scheduler
    # can swap the scheduler without altering original_pipe itself.
    pipe = CustomPipeline(**original_pipe.components)
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    # The slider uses step=1, so this only normalizes a possible float like 4.0.
    order = int(order)
    scheduler = "GHVB"
    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": "ghvb",
    }
    # without momentum (equivalent to PLMS)
    pipe = setup_scheduler(pipe, scheduler, order=order, beta=1.0, original_config=original_config)
    # A fresh, identically seeded generator per run keeps the two samples comparable.
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    ori_image = pipe(**params).images[0]
    # with momentum
    pipe = setup_scheduler(pipe, scheduler, order=order, beta=beta, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**params).images[0]
    return [ori_image, image]
if __name__ == "__main__":
    # Per-sampler component lists, keyed by tab name, wired to run_* handlers below.
    inputs = {}
    outputs = {}
    buttons = {}
    default_model_id = "Linaqruf/anything-v3.0"
    # Offer the default model as a real choice; other Hugging Face model IDs can
    # still be typed in because the dropdown allows custom values.
    list_models = [default_model_id]
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Momentum-Diffusion Demo
            A novel sampling method for diffusion models based on momentum to reduce artifacts
            """
        )
        model_dropdown = gr.Dropdown(list_models, label="Model ID", value=default_model_id, allow_custom_value=True)
        token_merging_checkbox = gr.Checkbox(label="Enable Token Merging", value=False)
        buttons["select_model"] = gr.Button("Select")

        # The tabs themselves are always shown (they are never part of the
        # visibility updates); the widgets inside each tab start hidden and are
        # revealed by prepare_model() once a model has been loaded.
        with gr.Tab("GHVB"):
            prompt3 = gr.Textbox(label="Prompt", value="a cozy cafe", visible=False)
            with gr.Row(visible=False) as row31:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=12)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)
            with gr.Row(visible=False) as row32:
                out1 = gr.Image(label="PLMS", interactive=False)
                out2 = gr.Image(label="GHVB", interactive=False)
            inputs["GHVB"] = [prompt3, order, beta, num_inference_steps, guidance_scale, seed]
            outputs["GHVB"] = [out1, out2]
            buttons["GHVB"] = gr.Button("Sample", visible=False)

        with gr.Tab("PLMS"):
            prompt2 = gr.Textbox(label="Prompt", value="1girl", visible=False)
            with gr.Row(visible=False) as row21:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, label="beta")
                momentum_type = gr.Dropdown(["Polyak's heavy ball", "Nesterov"], label="Momentum Type", value="Polyak's heavy ball")
                num_inference_steps = gr.Number(label="Number of steps", value=10)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)
            with gr.Row(visible=False) as row22:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)
            inputs["PLMS"] = [prompt2, order, beta, momentum_type, num_inference_steps, guidance_scale, seed]
            outputs["PLMS"] = [out1, out2]
            buttons["PLMS"] = gr.Button("Sample", visible=False)

        with gr.Tab("DPM-Solver++"):
            prompt1 = gr.Textbox(label="Prompt", value="1girl", visible=False)
            with gr.Row(visible=False) as row11:
                beta = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=15)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=20)
                seed = gr.Number(label="Seed", value=0)
            with gr.Row(visible=False) as row12:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)
            inputs["DPM-Solver++"] = [prompt1, beta, num_inference_steps, guidance_scale, seed]
            outputs["DPM-Solver++"] = [out1, out2]
            buttons["DPM-Solver++"] = gr.Button("Sample", visible=False)

        # Every component that stays hidden until a model has been loaded.
        hidden_until_loaded = [
            row11, row12, row21, row22, row31, row32,
            prompt1, prompt2, prompt3,
            buttons["DPM-Solver++"], buttons["PLMS"], buttons["GHVB"],
        ]

        def prepare_model(model_id, enable_token_merging):
            """Load ``model_id`` into the shared globals and reveal the sampling UI.

            Args:
                model_id: Model identifier passed to ``CustomPipeline.from_pretrained``.
                enable_token_merging: If true, patch the pipeline with ``tomesd``
                    (ratio 0.5).

            Returns:
                A ``{component: gr.update(visible=True)}`` dict for every
                component in ``hidden_until_loaded``.
            """
            import gc

            global original_pipe, original_config, device
            if original_pipe is not None:
                # Rebind rather than `del` the global: if loading the new model
                # fails, the name must still exist for later clicks and demos.
                original_pipe = None
                # Release the old model before loading the next one.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            original_pipe = CustomPipeline.from_pretrained(model_id)
            if torch.cuda.is_available():
                # Prefer the second GPU when one exists, but fall back to the
                # first so single-GPU hosts don't fail with an invalid ordinal.
                gpu_index = 1 if torch.cuda.device_count() > 1 else 0
                device = torch.device(f"cuda:{gpu_index}")
            else:
                device = torch.device("cpu")
            original_pipe = original_pipe.to(device)
            if enable_token_merging:
                import tomesd
                tomesd.apply_patch(original_pipe, ratio=0.5)
                print("Enabled Token merging.")
            original_config = original_pipe.scheduler.config
            print(type(original_pipe))
            print(original_config)
            return {component: gr.update(visible=True) for component in hidden_until_loaded}

        buttons["select_model"].click(prepare_model, inputs=[model_dropdown, token_merging_checkbox], outputs=hidden_until_loaded)
        for sampler_name, handler in (
            ("DPM-Solver++", run_dpm_demo),
            ("PLMS", run_plms_demo),
            ("GHVB", run_ghvb_demo),
        ):
            buttons[sampler_name].click(handler, inputs=inputs[sampler_name], outputs=outputs[sampler_name])
    demo.launch()