File size: 5,255 Bytes
9a5dc44
 
 
 
88b57b3
9ca20bc
 
 
88b57b3
9a5dc44
9ca20bc
 
88b57b3
9ca20bc
9a5dc44
9ca20bc
 
 
 
 
 
 
9a5dc44
d818369
9a5dc44
d818369
9a5dc44
 
 
 
 
 
d818369
9a5dc44
9ca20bc
 
8557d69
9ca20bc
 
 
8557d69
9ca20bc
 
 
 
 
 
 
 
 
 
 
 
 
ef9ccc7
 
 
 
9ca20bc
 
 
 
 
 
 
 
 
 
 
9a5dc44
85f94cd
d818369
85f94cd
 
 
 
 
d818369
85f94cd
88b57b3
d818369
88b57b3
 
 
 
 
85f94cd
9ca20bc
 
 
 
 
 
 
85f94cd
d818369
85f94cd
9ca20bc
85f94cd
 
d818369
85f94cd
 
 
9ca20bc
85f94cd
 
d818369
85f94cd
d818369
 
 
 
85f94cd
 
 
 
 
 
 
d818369
85f94cd
d818369
 
 
 
85f94cd
 
 
 
 
 
 
d818369
85f94cd
d818369
 
 
85f94cd
 
 
 
 
 
 
 
 
 
 
 
 
9ca20bc
85f94cd
 
2423f10
85f94cd
 
9ca20bc
85f94cd
 
d818369
85f94cd
 
d818369
85f94cd
 
9ca20bc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#!/usr/bin/env python

from __future__ import annotations

import os
import shlex
import subprocess
import sys

import gradio as gr
import PIL.Image
import spaces
import torch
from diffusers import DPMSolverMultistepScheduler

# When running on Hugging Face Spaces, apply the local "patch" file to the vendored
# multires_textual_inversion checkout before anything is imported from it.
# The exit status of `patch` is deliberately not checked: a non-zero status
# (e.g. a hunk that does not apply) does not abort start-up.
if os.environ.get("SYSTEM") == "spaces":
    with open("patch") as patch_file:
        subprocess.run(
            shlex.split("patch -p1"),
            stdin=patch_file,
            cwd="multires_textual_inversion",
        )

# Put the vendored repository first on the import path so its modules
# (such as `pipeline`) can be imported as top-level modules.
sys.path.insert(0, "multires_textual_inversion")

from pipeline import MultiResPipeline, load_learned_concepts

# Markdown title rendered at the top of the demo page.
DESCRIPTION = (
    "# [Multiresolution Textual Inversion]"
    "(https://github.com/giannisdaras/multires_textual_inversion)"
)

# Markdown help text describing the three placeholder syntaxes the user can put
# in a prompt, and which learned concepts this demo ships with.
DETAILS = "\n".join(
    [
        "",
        "- To run the Semi Resolution-Dependent sampler, use the format: `<jane(number)>`.",
        "- To run the Fully Resolution-Dependent sampler, use the format: `<jane[number]>`.",
        "- To run the Fixed Resolution sampler, use the format: `<jane|number|>`.",
        "",
        "For this demo, only `<jane>`, `<gta5-artwork>` and `<cat-toy>` are available.",
        "Also, `number` should be an integer in [0, 9].",
        "",
    ]
)

# Use the first CUDA device when one is visible, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "ashllay/stable-diffusion-v1-5-archive"
# Request half precision only on a GPU; on CPU the loader's default dtype is kept.
if device.type != "cpu":
    pipe = MultiResPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
else:
    pipe = MultiResPipeline.from_pretrained(model_id)

# Replace the pipeline's scheduler with a multistep DPM-Solver++ (midpoint solver)
# using a scaled-linear beta schedule from 0.00085 to 0.012 over 1000 training
# timesteps and epsilon prediction.
pipe.scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    trained_betas=None,
    prediction_type="epsilon",
    thresholding=False,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    lower_order_final=True,
)

# Learned concepts loaded from disk; values are presumably the learned
# embedding tensors/parameters (confirm in pipeline.load_learned_concepts).
string_to_param_dict = load_learned_concepts(pipe, "textual_inversion_outputs/")
# Move every value onto the inference device, updating the mapping in place
# (iterate over a snapshot of the keys while reassigning values).
for concept_name in list(string_to_param_dict):
    string_to_param_dict[concept_name] = string_to_param_dict[concept_name].to(device)

pipe.to(device)
# The text encoder is moved explicitly as well, as in the original setup.
pipe.text_encoder.to(device)


@spaces.GPU
def run(prompt: str, n_images: int, n_steps: int, seed: int) -> list[PIL.Image.Image]:
    """Generate images for ``prompt`` with the multiresolution textual-inversion pipeline.

    Args:
        prompt: Text prompt; may contain learned-concept placeholders such as
            ``<jane(3)>``, ``<jane[0]>`` or ``<jane|3|>``.
        n_images: Number of images to generate; the prompt is repeated this many
            times to form the batch.
        n_steps: Number of inference steps passed to the pipeline.
        seed: Seed for the ``torch.Generator`` created on ``device``.

    Returns:
        The pipeline output for the batch. It is presumably a list of PIL images,
        as the annotation states; confirm against the patched ``MultiResPipeline``.
    """
    # Slider values can arrive as floats (e.g. 2.0 sent to the public "run" API
    # endpoint). List repetition and Generator.manual_seed both reject floats,
    # so coerce them to int up front.
    n_images = int(n_images)
    n_steps = int(n_steps)
    seed = int(seed)

    generator = torch.Generator(device=device).manual_seed(seed)
    return pipe(
        [prompt] * n_images,
        string_to_param_dict,
        num_inference_steps=n_steps,
        generator=generator,
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left-hand side: the prompt and sampling controls, grouped together.
        with gr.Group():
            with gr.Row():
                prompt = gr.Textbox(label="Prompt")
            with gr.Row():
                num_images = gr.Slider(label="Number of images", minimum=1, maximum=9, step=1, value=1)
            with gr.Row():
                num_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=10)
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=100)
            with gr.Row():
                run_button = gr.Button()

        # Right-hand side: the generated images.
        with gr.Column():
            result = gr.Gallery(label="Result", object_fit="scale-down")

    with gr.Row():
        with gr.Group():
            # Examples always run with fixed settings: 2 images, 10 steps, seed 100.
            fn = lambda x: run(x, 2, 10, 100)  # noqa: E731
            # One row of clickable examples per (label, prompts) pair, in display order.
            for example_label, example_prompts in (
                (
                    "Examples 1",
                    [
                        "an image of <gta5-artwork(0)>",
                        "an image of <jane(0)>",
                        "an image of <jane(3)>",
                        "an image of <cat-toy(0)>",
                    ],
                ),
                (
                    "Examples 2",
                    [
                        "an image of a cat in the style of <gta5-artwork(0)>",
                        "a painting of a dog in the style of <jane(0)>",
                        "a painting of a dog in the style of <jane(5)>",
                        "a painting of a <cat-toy(0)> in the style of <jane(3)>",
                    ],
                ),
                (
                    "Examples 3",
                    [
                        "an image of <jane[0]>",
                        "an image of <jane|0|>",
                        "an image of <jane|3|>",
                    ],
                ),
            ):
                with gr.Row():
                    gr.Examples(
                        label=example_label,
                        # gr.Examples expects one list of input values per example.
                        examples=[[example_prompt] for example_prompt in example_prompts],
                        inputs=prompt,
                        outputs=result,
                        fn=fn,
                    )

        # Both pressing Enter in the prompt box and clicking the button run generation.
        # Only the button click is exposed as a named API endpoint ("run").
        inputs = [prompt, num_images, num_steps, seed]
        prompt.submit(fn=run, inputs=inputs, outputs=result, api_name=False)
        run_button.click(fn=run, inputs=inputs, outputs=result, api_name="run")

    with gr.Accordion("About available prompts", open=False):
        gr.Markdown(DETAILS)

if __name__ == "__main__":
    # Enable request queuing (capped at 20 pending events), then start the server.
    demo.queue(max_size=20)
    demo.launch()