"""Gradio demo comparing diffusion sampling with and without momentum.

A model is loaded once through the "Select" button and cached in module-level
globals.  Each sampling tab then renders the same prompt twice with the same
seed: once with ``beta=1.0`` (the plain sampler) and once with the
user-chosen ``beta`` (momentum enabled), and shows both images side by side.
"""

import gradio as gr
import torch
from pipeline import CustomPipeline, setup_scheduler
from diffusers import StableDiffusionPipeline
from PIL import Image

# State shared between the model-selection callback and the sampling
# callbacks.  All three are populated by ``prepare_model``.
original_pipe = None
original_config = None
device = None


def _select_device():
    """Return the torch device the pipeline should be moved to.

    The second GPU (``cuda:1``) is used when more than one CUDA device is
    present; hosts with a single GPU fall back to ``cuda:0`` and hosts
    without CUDA fall back to the CPU.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    index = 1 if torch.cuda.device_count() > 1 else 0
    return torch.device(f"cuda:{index}")


def _sample_pair(scheduler, method, prompt, beta, num_inference_steps,
                 guidance_scale, seed, **scheduler_kwargs):
    """Generate one image without momentum and one with momentum.

    Args:
        scheduler: Scheduler name forwarded to ``setup_scheduler``.
        method: Value passed to the pipeline call as ``method``.
        prompt: Text prompt.
        beta: Momentum coefficient used for the second image.
        num_inference_steps: Number of sampling steps (converted to ``int``).
        guidance_scale: Classifier-free guidance scale.
        seed: Random seed (converted to ``int``); both images use it.
        **scheduler_kwargs: Extra keyword arguments for ``setup_scheduler``
            (for example ``order`` or ``momentum_type``).

    Returns:
        ``[image_without_momentum, image_with_momentum]``.

    Raises:
        RuntimeError: If no model has been loaded yet.
    """
    if original_pipe is None:
        raise RuntimeError("No model loaded: select a model before sampling.")

    # Build a fresh pipeline around the already-loaded components so that
    # swapping schedulers never touches the cached pipeline object.
    pipe = CustomPipeline(**original_pipe.components)
    seed = int(seed)
    params = {
        "prompt": prompt,
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": guidance_scale,
        "method": method,
    }

    images = []
    # beta=1.0 disables momentum, so the first pass is the baseline sampler.
    for current_beta in (1.0, beta):
        pipe = setup_scheduler(
            pipe,
            scheduler,
            beta=current_beta,
            original_config=original_config,
            **scheduler_kwargs,
        )
        # Re-seed before every pass so both images start from the same noise.
        params["generator"] = torch.Generator(device=device).manual_seed(seed)
        images.append(pipe(**params).images[0])
    return images


def run_dpm_demo(prompt, beta, num_inference_steps, guidance_scale, seed):
    """Sample with DPM-Solver++ (``beta=1.0``) and with momentum (``beta``).

    Returns:
        ``[image_without_momentum, image_with_momentum]``.
    """
    return _sample_pair(
        "DPM-Solver++", "dpm", prompt, beta, num_inference_steps,
        guidance_scale, seed,
    )


def run_plms_demo(prompt, order, beta, momentum_type, num_inference_steps,
                  guidance_scale, seed):
    """Sample with PLMS (``beta=1.0``) and with the chosen momentum type.

    ``momentum_type`` is the label selected in the UI; "Polyak's heavy ball"
    maps to method ``"hb"`` and anything else to ``"nt"``.

    Returns:
        ``[image_without_momentum, image_with_momentum]``.
    """
    method = "hb" if momentum_type == "Polyak's heavy ball" else "nt"
    return _sample_pair(
        "PLMS", method, prompt, beta, num_inference_steps, guidance_scale,
        seed, momentum_type=momentum_type, order=int(order),
    )


def run_ghvb_demo(prompt, order, beta, num_inference_steps, guidance_scale,
                  seed):
    """Sample with the GHVB scheduler at ``beta=1.0`` (PLMS) and at ``beta``.

    Returns:
        ``[plms_image, ghvb_image]``.
    """
    return _sample_pair(
        "GHVB", "ghvb", prompt, beta, num_inference_steps, guidance_scale,
        seed, order=int(order),
    )


if __name__ == "__main__":
    inputs = {}
    outputs = {}
    buttons = {}

    list_models = [
        "Linaqruf/anything-v3.0",
        "runwayml/stable-diffusion-v1-5",
        "dreamlike-art/dreamlike-photoreal-2.0",
    ]

    # Load every listed model once up front so its weights are downloaded
    # before the UI starts; the loaded pipeline itself is discarded.
    for model_id in list_models:
        pipeline = StableDiffusionPipeline.from_pretrained(model_id)
        del pipeline
        print(f"Downloaded {model_id}")

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Momentum-Diffusion Demo

            A novel sampling method for diffusion models based on momentum to reduce artifacts
            """
        )

        model_dropdown = gr.Dropdown(
            list_models,
            label="Model ID",
            value="Linaqruf/anything-v3.0",
            allow_custom_value=True,
        )
        enable_token_merging = gr.Checkbox(label="Enable Token Merging", value=False)
        buttons["select_model"] = gr.Button("Select")

        # NOTE(review): the tabs are created with visible=False and are not
        # among prepare_model's outputs -- confirm that the installed gradio
        # version still displays them once their contents become visible.
        with gr.Tab("GHVB", visible=False) as tab3:
            prompt3 = gr.Textbox(label="Prompt", value="a cozy cafe", visible=False)
            with gr.Row(visible=False) as row31:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=12)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)
            with gr.Row(visible=False) as row32:
                out1 = gr.Image(label="PLMS", interactive=False)
                out2 = gr.Image(label="GHVB", interactive=False)
            inputs["GHVB"] = [prompt3, order, beta, num_inference_steps, guidance_scale, seed]
            outputs["GHVB"] = [out1, out2]
            buttons["GHVB"] = gr.Button("Sample", visible=False)

        with gr.Tab("PLMS", visible=False) as tab2:
            prompt2 = gr.Textbox(label="Prompt", value="1girl", visible=False)
            with gr.Row(visible=False) as row21:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, label="beta")
                momentum_type = gr.Dropdown(
                    ["Polyak's heavy ball", "Nesterov"],
                    label="Momentum Type",
                    value="Polyak's heavy ball",
                )
                num_inference_steps = gr.Number(label="Number of steps", value=10)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)
            with gr.Row(visible=False) as row22:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)
            inputs["PLMS"] = [prompt2, order, beta, momentum_type, num_inference_steps, guidance_scale, seed]
            outputs["PLMS"] = [out1, out2]
            buttons["PLMS"] = gr.Button("Sample", visible=False)

        with gr.Tab("DPM-Solver++", visible=False) as tab1:
            prompt1 = gr.Textbox(label="Prompt", value="1girl", visible=False)
            with gr.Row(visible=False) as row11:
                beta = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=15)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=20)
                seed = gr.Number(label="Seed", value=0)
            with gr.Row(visible=False) as row12:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)
            inputs["DPM-Solver++"] = [prompt1, beta, num_inference_steps, guidance_scale, seed]
            outputs["DPM-Solver++"] = [out1, out2]
            buttons["DPM-Solver++"] = gr.Button("Sample", visible=False)

        # Components revealed once a model has been loaded.
        all_outputs = [
            row11, row12, row21, row22, row31, row32,
            prompt1, prompt2, prompt3,
            buttons["DPM-Solver++"], buttons["PLMS"], buttons["GHVB"],
        ]

        def prepare_model(model_id, enable_token_merging):
            """Load ``model_id`` into the shared globals and reveal the UI.

            Args:
                model_id: Model identifier passed to
                    ``CustomPipeline.from_pretrained``.
                enable_token_merging: When true, patch the pipeline with
                    ``tomesd`` at a merge ratio of 0.5.

            Returns:
                A mapping from every component in ``all_outputs`` to an
                update that makes it visible.
            """
            global original_pipe, original_config, device

            # Drop the previous pipeline before loading the next one, while
            # keeping the global defined should loading raise.
            original_pipe = None

            device = _select_device()
            original_pipe = CustomPipeline.from_pretrained(model_id).to(device)

            if enable_token_merging:
                # Imported lazily: tomesd is only needed for this option.
                import tomesd

                tomesd.apply_patch(original_pipe, ratio=0.5)
                print("Enabled Token merging.")

            original_config = original_pipe.scheduler.config
            print(type(original_pipe))
            print(original_config)

            return {component: gr.update(visible=True) for component in all_outputs}

        buttons["select_model"].click(
            prepare_model,
            inputs=[model_dropdown, enable_token_merging],
            outputs=all_outputs,
        )
        buttons["DPM-Solver++"].click(
            run_dpm_demo, inputs=inputs["DPM-Solver++"], outputs=outputs["DPM-Solver++"]
        )
        buttons["PLMS"].click(
            run_plms_demo, inputs=inputs["PLMS"], outputs=outputs["PLMS"]
        )
        buttons["GHVB"].click(
            run_ghvb_demo, inputs=inputs["GHVB"], outputs=outputs["GHVB"]
        )

    demo.launch()