susunghong commited on
Commit
4a7978f
·
1 Parent(s): 7d16311

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import random
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from PIL import Image, ImageOps
9
+ from diffusers import StableDiffusionSAGPipeline
10
+
11
+
12
+ help_text = """
13
+ Self-Attention Guidance
14
+ """
15
+
16
+
17
+ examples = [
18
+ [
19
+ ' ',
20
+ 50,
21
+ False,
22
+ 8978,
23
+ 7.5,
24
+ 1.0,
25
+ ],
26
+ [
27
+ '.',
28
+ 50,
29
+ False,
30
+ 8978,
31
+ 7.5,
32
+ 1.0,
33
+ ],
34
+ [
35
+ 'A cute Scottish Fold playing with a ball',
36
+ 50,
37
+ False,
38
+ 8978,
39
+ 5.0,
40
+ 1.0,
41
+ ],
42
+ [
43
+ 'A person with a happy dog',
44
+ 50,
45
+ False,
46
+ 8978,
47
+ 5.0,
48
+ 1.0,
49
+ ],
50
+ ]
51
+
52
+
53
+ model_id = "runwayml/stable-diffusion-v1-5"
54
+
55
+ def main():
56
+ pipe = StableDiffusionSAGPipeline.from_pretrained(model_id)#, torch_dtype=torch.float16)
57
+
58
+ def generate(
59
+ prompt: str,
60
+ steps: int,
61
+ randomize_seed: bool,
62
+ seed: int,
63
+ cfg_scale: float,
64
+ sag_scale: float,
65
+ ):
66
+ seed = random.randint(0, 100000) if randomize_seed else seed
67
+
68
+ generator = torch.manual_seed(seed)
69
+ ori_image = pipe(prompt, generator=generator, guidance_scale=cfg_scale, sag_scale=0.75).images[0]
70
+ generator = torch.manual_seed(seed)
71
+ sag_image = pipe(prompt, generator=generator, guidance_scale=cfg_scale, sag_scale=0.75).images[0]
72
+ return [seed, ori_image, sag_image]
73
+
74
+ def reset():
75
+ return [0, "Randomize Seed", 1371, 5.0, 0.75, None, None]
76
+
77
+ with gr.Blocks() as demo:
78
+ gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">
79
+ Self-Attention Guidance
80
+ """)
81
+ with gr.Row():
82
+ with gr.Column(scale=5):
83
+ prompt = gr.Textbox(lines=1, label="Enter your prompt", interactive=True)
84
+ with gr.Column(scale=1, min_width=60):
85
+ generate_button = gr.Button("Generate")
86
+ with gr.Column(scale=1, min_width=60):
87
+ reset_button = gr.Button("Reset")
88
+
89
+ with gr.Row():
90
+ ori_image = gr.Image(label="CFG", type="pil", interactive=False)
91
+ sag_image = gr.Image(label="SAG + CFG", type="pil", interactive=False)
92
+ ori_image.style(height=512, width=512)
93
+ sag_image.style(height=512, width=512)
94
+
95
+ with gr.Row():
96
+ steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
97
+ randomize_seed = gr.Radio(
98
+ ["Fix Seed", "Randomize Seed"],
99
+ value="Fix Seed",
100
+ type="index",
101
+ show_label=False,
102
+ interactive=True,
103
+ )
104
+ seed = gr.Number(value=8978, precision=0, label="Seed", interactive=True)
105
+
106
+ with gr.Row():
107
+ cfg_scale = gr.Slider(
108
+ label="Guidance Scale", minimum=0, maximum=10, value=5.0, step=0.1
109
+ )
110
+ sag_scale = gr.Slider(
111
+ label="Self-Attention Guidance Scale", minimum=0, maximum=1.0, value=0.75, step=0.05
112
+ )
113
+
114
+ ex = gr.Examples(
115
+ examples=examples,
116
+ fn=generate,
117
+ inputs=[
118
+ prompt,
119
+ steps,
120
+ randomize_seed,
121
+ seed,
122
+ cfg_scale,
123
+ sag_scale,
124
+ ],
125
+ outputs=[seed, ori_image, sag_image],
126
+ cache_examples=True,
127
+ preprocess=False,
128
+ postprocess=False
129
+ )
130
+
131
+ gr.Markdown(help_text)
132
+
133
+ generate_button.click(
134
+ fn=generate,
135
+ inputs=[
136
+ prompt,
137
+ steps,
138
+ randomize_seed,
139
+ seed,
140
+ cfg_scale,
141
+ sag_scale,
142
+ ],
143
+ outputs=[seed, ori_image, sag_image],
144
+ )
145
+ reset_button.click(
146
+ fn=reset,
147
+ inputs=[],
148
+ outputs=[steps, randomize_seed, seed, cfg_scale, sag_scale, ori_image, sag_image],
149
+ )
150
+
151
+ demo.queue(concurrency_count=1)
152
+ demo.launch(share=False)
153
+
154
+
155
+ if __name__ == "__main__":
156
+ main()