hysts's picture
hysts HF staff
Update the base model id
8557d69
raw
history blame contribute delete
No virus
5.26 kB
#!/usr/bin/env python
from __future__ import annotations
import os
import shlex
import subprocess
import sys
import gradio as gr
import PIL.Image
import spaces
import torch
from diffusers import DPMSolverMultistepScheduler
# Directory of the vendored upstream checkout that provides ``pipeline``.
_REPO_DIR = "multires_textual_inversion"

# On Hugging Face Spaces (SYSTEM == "spaces"), apply the local ``patch`` file
# to the vendored checkout with ``patch -p1``. The exit status is not checked,
# so a non-zero return from ``patch`` does not stop the app.
if os.environ.get("SYSTEM") == "spaces":
    with open("patch") as patch_stream:
        subprocess.run(shlex.split("patch -p1"), cwd=_REPO_DIR, stdin=patch_stream)

# Put the vendored checkout first on the import path so ``from pipeline import``
# below resolves to it.
sys.path.insert(0, _REPO_DIR)
from pipeline import MultiResPipeline, load_learned_concepts
# Markdown header rendered at the top of the demo page.
DESCRIPTION = "# [Multiresolution Textual Inversion](https://github.com/giannisdaras/multires_textual_inversion)"

# Markdown help text shown in the "About available prompts" accordion.
# The blank line after the bullet list is required: without it, Markdown
# treats the following paragraph as a lazy continuation of the last bullet.
DETAILS = """
- To run the Semi Resolution-Dependent sampler, use the format: `<jane(number)>`.
- To run the Fully Resolution-Dependent sampler, use the format: `<jane[number]>`.
- To run the Fixed Resolution sampler, use the format: `<jane|number|>`.

For this demo, only `<jane>`, `<gta5-artwork>` and `<cat-toy>` are available.
Also, `number` should be an integer in [0, 9].
"""
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "ashllay/stable-diffusion-v1-5-archive"

# Request float16 weights only off-CPU; on CPU the pipeline is loaded with
# from_pretrained's default dtype.
if device.type != "cpu":
    pipe = MultiResPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
else:
    pipe = MultiResPipeline.from_pretrained(model_id)

# Replace the pipeline's scheduler with an explicitly configured
# multistep DPM-Solver++ scheduler.
pipe.scheduler = DPMSolverMultistepScheduler(
    # Training noise schedule.
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    trained_betas=None,
    prediction_type="epsilon",
    # Solver behaviour.
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    lower_order_final=True,
    thresholding=False,
)

# Load the learned concepts from ``textual_inversion_outputs/`` and move each
# value to the inference device. The mapping is updated in place while
# iterating over a snapshot of its items.
string_to_param_dict = load_learned_concepts(pipe, "textual_inversion_outputs/")
for concept_key, learned_param in list(string_to_param_dict.items()):
    string_to_param_dict[concept_key] = learned_param.to(device)
# NOTE(review): on GPU the pipeline is float16, but these params are only
# moved, not cast; presumably the pipeline handles the dtype — confirm.

pipe.to(device)
# Kept as in the original; pipe.to(device) presumably already moves the text
# encoder as well.
pipe.text_encoder.to(device)
@spaces.GPU
def run(prompt: str, n_images: int, n_steps: int, seed: int) -> list[PIL.Image.Image]:
    """Generate images for ``prompt`` with the multiresolution pipeline.

    Args:
        prompt: Text prompt; may contain learned concept tokens such as
            ``<jane(3)>`` (see ``DETAILS`` for the accepted formats).
        n_images: Number of images; the prompt is repeated this many times
            to form the batch.
        n_steps: Number of inference steps passed to the pipeline.
        seed: Seed for the ``torch.Generator`` so results are reproducible.

    Returns:
        The pipeline output for the batch (annotated as a list of PIL
        images); it is displayed directly in a ``gr.Gallery``.
    """
    # Slider values (and API clients calling the "run" endpoint) may arrive
    # as floats. List repetition and Generator.manual_seed both reject
    # floats, so coerce everything to int first.
    n_images = int(n_images)
    n_steps = int(n_steps)
    generator = torch.Generator(device=device).manual_seed(int(seed))
    return pipe(
        [prompt] * n_images,
        string_to_param_dict,
        num_inference_steps=n_steps,
        generator=generator,
    )
# Generation settings used whenever Gradio invokes the examples' callback.
EXAMPLE_NUM_IMAGES = 2
EXAMPLE_NUM_STEPS = 10
EXAMPLE_SEED = 100

# (panel label, example rows) for each gr.Examples panel, in display order.
# Each row holds a single value for the prompt textbox.
EXAMPLE_SETS = [
    (
        "Examples 1",
        [
            ["an image of <gta5-artwork(0)>"],
            ["an image of <jane(0)>"],
            ["an image of <jane(3)>"],
            ["an image of <cat-toy(0)>"],
        ],
    ),
    (
        "Examples 2",
        [
            ["an image of a cat in the style of <gta5-artwork(0)>"],
            ["a painting of a dog in the style of <jane(0)>"],
            ["a painting of a dog in the style of <jane(5)>"],
            ["a painting of a <cat-toy(0)> in the style of <jane(3)>"],
        ],
    ),
    (
        "Examples 3",
        [
            ["an image of <jane[0]>"],
            ["an image of <jane|0|>"],
            ["an image of <jane|3|>"],
        ],
    ),
]


def run_example(prompt_text: str) -> list[PIL.Image.Image]:
    """Run ``run`` on an example prompt with the fixed example settings."""
    return run(prompt_text, EXAMPLE_NUM_IMAGES, EXAMPLE_NUM_STEPS, EXAMPLE_SEED)


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # Left: prompt and generation controls; right: the output gallery.
    with gr.Row():
        with gr.Group():
            with gr.Row():
                prompt = gr.Textbox(label="Prompt")
            with gr.Row():
                num_images = gr.Slider(
                    label="Number of images",
                    minimum=1,
                    maximum=9,
                    step=1,
                    value=1,
                )
            with gr.Row():
                num_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=10,
                )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=100)
            with gr.Row():
                run_button = gr.Button()
        with gr.Column():
            result = gr.Gallery(label="Result", object_fit="scale-down")

    # One gr.Examples panel per entry in EXAMPLE_SETS; each fills the prompt
    # box and uses run_example (fixed settings) as its callback.
    with gr.Row():
        with gr.Group():
            for example_label, example_rows in EXAMPLE_SETS:
                with gr.Row():
                    gr.Examples(
                        label=example_label,
                        examples=example_rows,
                        inputs=prompt,
                        outputs=result,
                        fn=run_example,
                    )

    inputs = [
        prompt,
        num_images,
        num_steps,
        seed,
    ]
    # Pressing Enter in the prompt box runs generation without exposing an
    # extra API endpoint; the button is the public "run" endpoint.
    prompt.submit(
        fn=run,
        inputs=inputs,
        outputs=result,
        api_name=False,
    )
    run_button.click(
        fn=run,
        inputs=inputs,
        outputs=result,
        api_name="run",
    )

    with gr.Accordion("About available prompts", open=False):
        gr.Markdown(DETAILS)
if __name__ == "__main__":
    # Enable Gradio's request queue (at most 20 pending requests), then start
    # the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()