vumichien committed on
Commit
75618ca
·
1 Parent(s): a041f35

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -0
app.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import streamlit as st
3
+ from streamlit_drawable_canvas import st_canvas
4
+ from streamlit_lottie import st_lottie
5
+ from streamlit_option_menu import option_menu
6
+ import requests
7
import os
import subprocess

# Absolute location of the ControlNet checkout. The modules imported further
# down (share, config, annotator, cldm) are presumably provided by this
# repository -- confirm against the deployment layout.
CONTROLNET_DIR = '/home/user/app/ControlNet'

# The working directory is changed below and stays changed for the lifetime of
# the process, so a relative `git clone` issued on a later run of this script
# would create a nested ControlNet/ControlNet copy. Clone to an explicit path,
# and only when the checkout is not there yet.
if not os.path.isdir(CONTROLNET_DIR):
    subprocess.run(
        ['git', 'clone', 'https://github.com/lllyasviel/ControlNet.git', CONTROLNET_DIR],
        check=True,  # fail loudly here instead of at the chdir/imports below
    )
os.chdir(CONTROLNET_DIR)
10
+
11
+ from share import *
12
+ import config
13
+
14
+ import cv2
15
+ import einops
16
+ import gradio as gr
17
+ import numpy as np
18
+ import torch
19
+ import random
20
+
21
+ from huggingface_hub import hf_hub_download
22
+ from pytorch_lightning import seed_everything
23
+ from annotator.util import resize_image, HWC3
24
+ from annotator.hed import HEDdetector, nms
25
+ from cldm.model import create_model, load_state_dict
26
+ from cldm.ddim_hacked import DDIMSampler
27
+
28
# Browser-tab title/icon and overall layout for the Streamlit page.
_PAGE_SETTINGS = {
    "page_title": "ControllNet",
    "page_icon": "🖥️",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
34
+
35
@st.cache_resource
def load_model():
    """Build the scribble ControlNet model and return it on the GPU.

    The checkpoint file is fetched with ``hf_hub_download`` from the
    ``lllyasviel/ControlNet`` repository, the architecture is created from
    ``./models/cldm_v15.yaml`` (relative to the current working directory),
    and the weights are loaded before the model is moved to CUDA.
    """
    checkpoint_path = hf_hub_download('lllyasviel/ControlNet', 'models/control_sd15_scribble.pth')
    net = create_model('./models/cldm_v15.yaml').cpu()
    weights = load_state_dict(checkpoint_path, location='cuda')
    net.load_state_dict(weights)
    return net.cuda()
42
+
43
+
44
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
    """Generate images from an input picture using the module-level ControlNet model.

    The input is reduced to a single channel, run through the module-level
    ``apply_hed`` detector, thinned/blurred/thresholded into a binary map and
    used as the control signal for DDIM sampling with ``ddim_sampler``.

    Args:
        input_image: NumPy image array. A 3-D array contributes only its first
            channel; a 2-D array is used as is.
        prompt: Main text prompt.
        a_prompt: Extra text appended to ``prompt`` (joined with ``', '``).
        n_prompt: Negative prompt used for the unconditional branch.
        num_samples: Number of images to generate in one batch.
        image_resolution: Resolution passed to ``resize_image`` for the output size.
        detect_resolution: Resolution passed to ``resize_image`` before detection.
        ddim_steps: Number of DDIM sampling steps.
        guess_mode: When true, the unconditional branch gets no control map and
            the control scales decay geometrically instead of being constant.
        strength: Base control scale.
        scale: Unconditional guidance scale.
        seed: Random seed; ``-1`` picks one at random in ``[0, 65535]``.
        eta: DDIM eta forwarded to the sampler.

    Returns:
        A list of ``num_samples`` ``uint8`` arrays in height-width-channel order.
    """
    with torch.no_grad():
        # A 2-D (grayscale) array has no channel axis to index, so only slice
        # when a channel axis is present; HWC3 receives a 2-D array either way.
        if input_image.ndim == 3:
            input_image = input_image[:, :, 0]
        input_image = HWC3(input_image)

        detected_map = apply_hed(resize_image(input_image, detect_resolution))
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        # Bring the edge map to the output size, thin it, then blur and
        # threshold so that it ends up strictly binary (0 or 255).
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
        detected_map = nms(detected_map, 127, 3.0)
        detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
        detected_map[detected_map > 4] = 255
        detected_map[detected_map < 255] = 0

        # Scale to [0, 1], repeat once per sample, and move channels first.
        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
        samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
                                         shape, cond, verbose=False, eta=eta,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=un_cond)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        # Decode and map the decoder output from [-1, 1] to uint8 [0, 255].
        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        return [x_samples[i] for i in range(num_samples)]
92
+
93
@st.cache_data
def load_lottieurl(url: str, timeout: float = 10.0):
    """Download a Lottie animation and return its parsed JSON.

    Args:
        url: Address of the Lottie JSON file.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot block the page indefinitely.

    Returns:
        The decoded JSON payload, or ``None`` when the request fails, the
        status code is not 200, or the body is not valid JSON.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Network problems are treated like any other failed download.
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:
        # Body was not JSON.
        return None
99
+
100
@st.cache_resource
def _load_hed_detector():
    """Create the HED detector once and reuse it across script runs.

    NOTE(review): HEDdetector() presumably builds a network and loads its
    weights, which is why it is cached like the main model -- confirm.
    """
    return HEDdetector()


# Module-level objects used by process().
model = load_model()
ddim_sampler = DDIMSampler(model)
apply_hed = _load_hed_detector()
103
+
104
# Markdown/HTML shown at the bottom of the sidebar.
_SIDEBAR_ABOUT = """
___
<p style='text-align: center'>
ControlNet is as fast as fine-tuning a diffusion model to support additional input conditions
<br/>
<a href="https://arxiv.org/abs/2302.05543" target="_blank">Article</a>
</p>
<p style='text-align: center; font-size: 14px;'>
Spaces creating by
<br/>
<a href="https://www.linkedin.com/in/vumichien/" target="_blank">Chien Vu</a>
<br/>
<img src='https://visitor-badge.glitch.me/badge?page_id=Canvas.ControlNet' alt='visitor badge'>
</p>
"""

_SPINNER_TEXT = "It may take up to 1 minute under high load. Generating images..."


def _advanced_options():
    """Render the 'Advanced option' expander and return its values.

    The returned dict is keyed by the matching ``process`` parameter names so
    it can be forwarded with ``**``.
    """
    with st.expander('Advanced option', expanded=False):
        left, right = st.columns(2)
        with left:
            image_resolution = st.slider(label="Image Resolution", min_value=256, max_value=512, value=512, step=256)
            strength = st.slider(label="Control Strength", min_value=0.0, max_value=2.0, value=1.0, step=0.01)
            guess_mode = st.checkbox(label='Guess Mode', value=False)
            detect_resolution = st.slider(label="HED Resolution", min_value=128, max_value=1024, value=512, step=1)
            ddim_steps = st.slider(label="Steps", min_value=1, max_value=100, value=20, step=1)
        with right:
            scale = st.slider(label="Guidance Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
            seed = st.number_input(label="Seed", min_value=-1, value=-1)
            eta = st.number_input(label="eta (DDIM)", value=0.0)
            a_prompt = st.text_input(label="Added Prompt", value='best quality, extremely detailed')
            n_prompt = st.text_input(label="Negative Prompt",
                                     value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
    return {
        "a_prompt": a_prompt,
        "n_prompt": n_prompt,
        "image_resolution": image_resolution,
        "detect_resolution": detect_resolution,
        "ddim_steps": ddim_steps,
        "guess_mode": guess_mode,
        "strength": strength,
        "scale": scale,
        "seed": seed,
        "eta": eta,
    }


def _render_sidebar():
    """Draw the sidebar (animation, mode menu, credits) and return the chosen mode."""
    lottie_penguin = load_lottieurl('https://assets5.lottiefiles.com/datafiles/B8q1AyJ5t1wb5S8a2ggTqYNxS1WiKN9mjS76TBpw/articulation/articulation.json')
    with st.sidebar:
        # load_lottieurl returns None when the download fails; the animation
        # is decorative, so skip it rather than break the page.
        if lottie_penguin is not None:
            st_lottie(lottie_penguin, height=200)
        choose = option_menu("Generate image", ["Upload", "Canvas"],
                             icons=['collection', 'file-plus'],
                             menu_icon="infinity", default_index=0,
                             styles={
                                 "container": {"padding": ".0rem", "font-size": "14px"},
                                 "nav-link-selected": {"color": "#000000", "font-size": "16px"},
                             }
                             )
    st.sidebar.markdown(_SIDEBAR_ABOUT, unsafe_allow_html=True)
    return choose


def _upload_page():
    """Form that generates an image from an uploaded picture."""
    with st.form(key='generate_form'):
        upload_file = st.file_uploader("Upload image", type=["png", "jpg", "jpeg"])
        prompt = st.text_input(label="Prompt", placeholder='Type your instruction')
        input_col, output_col = st.columns(2)
        options = _advanced_options()
        generate_button = st.form_submit_button(label='Generate Image')

    if not generate_button:
        return
    if upload_file is None:
        # Submitting without a file previously hit an unbound `input_image`.
        st.warning('Please upload an image before generating.')
        return

    # Normalise to 3-channel RGB so grayscale/palette uploads yield an
    # H x W x 3 array like any other picture.
    input_image = np.asarray(Image.open(upload_file).convert('RGB'))
    with st.spinner(text=_SPINNER_TEXT):
        results = process(input_image, prompt, num_samples=1, **options)
    input_col.image(input_image, channels='RGB', width=None, clamp=False, caption='Input image')
    output_col.image(results[0], channels='RGB', width=None, clamp=False, caption='Generated image')


def _canvas_page():
    """Form that generates an image from a free-hand drawing."""
    with st.form(key='canvas_form'):
        # Drawing settings are rendered in the sidebar.
        stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 3)
        stroke_color = st.sidebar.color_picker("Stroke color hex: ")
        bg_color = st.sidebar.color_picker("Background color hex: ", "#eee")
        bg_height = st.sidebar.slider("Canvas height", min_value=256, max_value=512, value=512, step=64)
        bg_width = st.sidebar.slider("Canvas width", min_value=256, max_value=512, value=512, step=64)
        realtime_update = st.sidebar.checkbox("Update in realtime", True)

        # Drawing canvas on the left, generated image on the right.
        canvas_col, output_col = st.columns(2)
        with canvas_col:
            canvas_result = st_canvas(
                fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
                stroke_width=stroke_width,
                stroke_color=stroke_color,
                background_color=bg_color,
                background_image=None,
                update_streamlit=realtime_update,
                height=bg_height,
                width=bg_width,
                drawing_mode="freedraw",
                point_display_radius=0,
                key="canvas",
            )
        # NOTE(review): whether the prompt box belongs inside the canvas column
        # could not be determined from the flattened source -- confirm layout.
        prompt = st.text_input(label="Prompt", placeholder='Type your instruction')
        options = _advanced_options()
        generate_button = st.form_submit_button(label='Generate Image')

    # Only generate when the user asks for it and there is drawing data.
    if not generate_button or canvas_result.image_data is None:
        return

    input_image = canvas_result.image_data
    with st.spinner(text=_SPINNER_TEXT):
        results = process(input_image, prompt, num_samples=1, **options)
    # Scale the result back to the canvas size before showing it.
    height, width = input_image.shape[:2]
    output_image = cv2.resize(results[0], (width, height), interpolation=cv2.INTER_AREA)
    output_col.image(output_image, channels='RGB', width=384, clamp=True, caption='Generated image')


def main():
    """Render the page: header, sidebar, and the page for the selected mode."""
    st.header("Generate image with ControllNet")
    choose = _render_sidebar()
    if choose == 'Upload':
        _upload_page()
    elif choose == 'Canvas':
        _canvas_page()


if __name__ == '__main__':
    main()