DmitrMakeev commited on
Commit
6c7cd3d
·
verified ·
1 Parent(s): 4aa5967

Create app_sketch.py

Browse files
Files changed (1) hide show
  1. app_sketch.py +166 -0
app_sketch.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import gradio as gr
4
+ import PIL.Image
5
+ import torch
6
+ import torchvision.transforms.functional as TF
7
+
8
+ from model import Model
9
+ from utils import (
10
+ DEFAULT_STYLE_NAME,
11
+ MAX_SEED,
12
+ STYLE_NAMES,
13
+ apply_style,
14
+ randomize_seed_fn,
15
+ )
16
+
17
+
18
+ def create_demo(model: Model) -> gr.Blocks:
19
+ def run(
20
+ image: PIL.Image.Image,
21
+ prompt: str,
22
+ negative_prompt: str,
23
+ style_name: str = DEFAULT_STYLE_NAME,
24
+ num_steps: int = 25,
25
+ guidance_scale: float = 5,
26
+ adapter_conditioning_scale: float = 0.8,
27
+ adapter_conditioning_factor: float = 0.8,
28
+ seed: int = 0,
29
+ progress=gr.Progress(track_tqdm=True),
30
+ ) -> PIL.Image.Image:
31
+ image = image.convert("RGB")
32
+ image = TF.to_tensor(image) > 0.5
33
+ image = TF.to_pil_image(image.to(torch.float32))
34
+
35
+ prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
36
+
37
+ return model.run(
38
+ image=image,
39
+ prompt=prompt,
40
+ negative_prompt=negative_prompt,
41
+ adapter_name="sketch",
42
+ num_inference_steps=num_steps,
43
+ guidance_scale=guidance_scale,
44
+ adapter_conditioning_scale=adapter_conditioning_scale,
45
+ adapter_conditioning_factor=adapter_conditioning_factor,
46
+ seed=seed,
47
+ apply_preprocess=False,
48
+ )[1]
49
+
50
+ with gr.Blocks() as demo:
51
+ with gr.Row():
52
+ with gr.Column():
53
+ with gr.Group():
54
+ image = gr.Image(
55
+ source="canvas",
56
+ tool="sketch",
57
+ type="pil",
58
+ image_mode="L",
59
+ invert_colors=True,
60
+ shape=(1024, 1024),
61
+ brush_radius=4,
62
+ height=600,
63
+ )
64
+ prompt = gr.Textbox(label="Prompt")
65
+ style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
66
+ run_button = gr.Button("Run")
67
+ with gr.Accordion("Advanced options", open=False):
68
+ negative_prompt = gr.Textbox(
69
+ label="Negative prompt",
70
+ value=" extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured",
71
+ )
72
+ num_steps = gr.Slider(
73
+ label="Number of steps",
74
+ minimum=1,
75
+ maximum=50,
76
+ step=1,
77
+ value=25,
78
+ )
79
+ guidance_scale = gr.Slider(
80
+ label="Guidance scale",
81
+ minimum=0.1,
82
+ maximum=10.0,
83
+ step=0.1,
84
+ value=5,
85
+ )
86
+ adapter_conditioning_scale = gr.Slider(
87
+ label="Adapter conditioning scale",
88
+ minimum=0.5,
89
+ maximum=1,
90
+ step=0.1,
91
+ value=0.8,
92
+ )
93
+ adapter_conditioning_factor = gr.Slider(
94
+ label="Adapter conditioning factor",
95
+ info="Fraction of timesteps for which adapter should be applied",
96
+ minimum=0.5,
97
+ maximum=1,
98
+ step=0.1,
99
+ value=0.8,
100
+ )
101
+ seed = gr.Slider(
102
+ label="Seed",
103
+ minimum=0,
104
+ maximum=MAX_SEED,
105
+ step=1,
106
+ value=0,
107
+ )
108
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
109
+ with gr.Column():
110
+ result = gr.Image(label="Result", height=600)
111
+
112
+ inputs = [
113
+ image,
114
+ prompt,
115
+ negative_prompt,
116
+ style,
117
+ num_steps,
118
+ guidance_scale,
119
+ adapter_conditioning_scale,
120
+ adapter_conditioning_factor,
121
+ seed,
122
+ ]
123
+ prompt.submit(
124
+ fn=randomize_seed_fn,
125
+ inputs=[seed, randomize_seed],
126
+ outputs=seed,
127
+ queue=False,
128
+ api_name=False,
129
+ ).then(
130
+ fn=run,
131
+ inputs=inputs,
132
+ outputs=result,
133
+ api_name=False,
134
+ )
135
+ negative_prompt.submit(
136
+ fn=randomize_seed_fn,
137
+ inputs=[seed, randomize_seed],
138
+ outputs=seed,
139
+ queue=False,
140
+ api_name=False,
141
+ ).then(
142
+ fn=run,
143
+ inputs=inputs,
144
+ outputs=result,
145
+ api_name=False,
146
+ )
147
+ run_button.click(
148
+ fn=randomize_seed_fn,
149
+ inputs=[seed, randomize_seed],
150
+ outputs=seed,
151
+ queue=False,
152
+ api_name=False,
153
+ ).then(
154
+ fn=run,
155
+ inputs=inputs,
156
+ outputs=result,
157
+ api_name=False,
158
+ )
159
+
160
+ return demo
161
+
162
+
163
+ if __name__ == "__main__":
164
+ model = Model("sketch")
165
+ demo = create_demo(model)
166
+ demo.queue(max_size=20).launch()