#!/usr/bin/env python

from __future__ import annotations

import os
import shlex
import subprocess
import sys

import gradio as gr
import PIL.Image
import spaces
import torch
from diffusers import DPMSolverMultistepScheduler

# On Hugging Face Spaces, apply the bundled patch to the upstream repo before importing it.
if os.getenv("SYSTEM") == "spaces":
    with open("patch") as f:
        subprocess.run(shlex.split("patch -p1"), cwd="multires_textual_inversion", stdin=f)

sys.path.insert(0, "multires_textual_inversion")

from pipeline import MultiResPipeline, load_learned_concepts

DESCRIPTION = "# [Multiresolution Textual Inversion](https://github.com/giannisdaras/multires_textual_inversion)"

DETAILS = """
- To run the Semi Resolution-Dependent sampler, use the format: `<jane(number)>`.
- To run the Fully Resolution-Dependent sampler, use the format: `<jane[number]>`.
- To run the Fixed Resolution sampler, use the format: `<jane|number|>`.

For this demo, only `<gta5-artwork>`, `<cat-toy>` and `<jane>` are available. Also, `number` should be an integer in [0, 9].
"""

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "ashllay/stable-diffusion-v1-5-archive"
# Use fp16 weights on GPU; fall back to full precision on CPU.
if device.type == "cpu":
    pipe = MultiResPipeline.from_pretrained(model_id)
else:
    pipe = MultiResPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    trained_betas=None,
    prediction_type="epsilon",
    thresholding=False,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    lower_order_final=True,
)
# Load the learned concept embeddings and move them to the target device.
string_to_param_dict = load_learned_concepts(pipe, "textual_inversion_outputs/")
for k, v in list(string_to_param_dict.items()):
    string_to_param_dict[k] = v.to(device)
pipe.to(device)
pipe.text_encoder.to(device)


@spaces.GPU
def run(prompt: str, n_images: int, n_steps: int, seed: int) -> list[PIL.Image.Image]:
    generator = torch.Generator(device=device).manual_seed(seed)
    return pipe(
        [prompt] * n_images,
        string_to_param_dict,
        num_inference_steps=n_steps,
        generator=generator,
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Group():
            with gr.Row():
                prompt = gr.Textbox(label="Prompt")
            with gr.Row():
                num_images = gr.Slider(
                    label="Number of images",
                    minimum=1,
                    maximum=9,
                    step=1,
                    value=1,
                )
            with gr.Row():
                num_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=10,
                )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=100)
            with gr.Row():
                run_button = gr.Button()
        with gr.Column():
            result = gr.Gallery(label="Result", object_fit="scale-down")
    with gr.Row():
        with gr.Group():
            fn = lambda x: run(x, 2, 10, 100)
            # Example prompts use the concept-token formats documented in DETAILS above.
            with gr.Row():
                gr.Examples(
                    label="Examples 1",
                    examples=[
                        ["an image of <gta5-artwork(0)>"],
                        ["an image of <jane(0)>"],
                        ["an image of <jane(3)>"],
                        ["an image of <cat-toy(0)>"],
                    ],
                    inputs=prompt,
                    outputs=result,
                    fn=fn,
                )
            with gr.Row():
                gr.Examples(
                    label="Examples 2",
                    examples=[
                        ["an image of a cat in the style of <gta5-artwork(0)>"],
                        ["a painting of a dog in the style of <jane(0)>"],
                        ["a painting of a dog in the style of <jane(3)>"],
                        ["a painting of a <cat-toy(0)> in the style of <jane(0)>"],
                    ],
                    inputs=prompt,
                    outputs=result,
                    fn=fn,
                )
            with gr.Row():
                gr.Examples(
                    label="Examples 3",
                    examples=[
                        ["an image of <jane[0]>"],
                        ["an image of <jane|0|>"],
                        ["an image of <jane>"],
                    ],
                    inputs=prompt,
                    outputs=result,
                    fn=fn,
                )

    inputs = [
        prompt,
        num_images,
        num_steps,
        seed,
    ]
    prompt.submit(
        fn=run,
        inputs=inputs,
        outputs=result,
        api_name=False,
    )
    run_button.click(
        fn=run,
        inputs=inputs,
        outputs=result,
        api_name="run",
    )

    with gr.Accordion("About available prompts", open=False):
        gr.Markdown(DETAILS)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
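
# A minimal sketch of calling this demo programmatically through `gradio_client`,
# using the "run" endpoint registered on `run_button.click` above. The Space ID
# below is hypothetical; substitute the repo ID where this app is actually hosted.
#
#   from gradio_client import Client
#
#   client = Client("user/multires-textual-inversion")  # hypothetical Space ID
#   images = client.predict(
#       "an image of <jane(3)>",  # prompt
#       1,                        # number of images
#       10,                       # number of inference steps
#       100,                      # seed
#       api_name="/run",
#   )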