BertChristiaens
committed on
Commit
•
41e92f0
1
Parent(s):
3790166
refactor
Browse files
models.py
CHANGED
@@ -29,7 +29,6 @@ def flush():
|
|
29 |
|
30 |
class ControlNetPipeline:
|
31 |
def __init__(self):
|
32 |
-
print(torch.__version__)
|
33 |
self.in_use = False
|
34 |
self.controlnet = ControlNetModel.from_pretrained(
|
35 |
"BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
|
@@ -43,7 +42,6 @@ class ControlNetPipeline:
|
|
43 |
|
44 |
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
45 |
self.pipe.enable_xformers_memory_efficient_attention()
|
46 |
-
# self.pipe.enable_attention_slicing("max")
|
47 |
self.pipe = self.pipe.to("cuda")
|
48 |
|
49 |
self.waiting_queue = []
|
@@ -72,21 +70,44 @@ class ControlNetPipeline:
|
|
72 |
self.waiting_queue.pop(0)
|
73 |
flush()
|
74 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
90 |
|
91 |
|
92 |
def convolution(mask: Image.Image, size=9) -> Image:
|
@@ -154,53 +175,10 @@ def get_inpainting_pipeline() -> StableDiffusionInpaintPipeline:
|
|
154 |
Returns:
|
155 |
StableDiffusionInpaintPipeline: inpainting pipeline
|
156 |
"""
|
157 |
-
pipe =
|
158 |
-
"stabilityai/stable-diffusion-2-inpainting",
|
159 |
-
torch_dtype=torch.float16,
|
160 |
-
safety_checker=None,
|
161 |
-
)
|
162 |
-
|
163 |
-
pipe.enable_xformers_memory_efficient_attention()
|
164 |
-
pipe = pipe.to("cuda")
|
165 |
-
|
166 |
return pipe
|
167 |
|
168 |
|
169 |
-
def make_grid_parameters(grid_search: Dict, params: Dict) -> List[Dict]:
|
170 |
-
"""Method to make grid parameters
|
171 |
-
Args:
|
172 |
-
grid_search (Dict): grid search parameters
|
173 |
-
params (Dict): fixed parameters
|
174 |
-
Returns:
|
175 |
-
List[Dict]: grid parameters
|
176 |
-
"""
|
177 |
-
options = []
|
178 |
-
|
179 |
-
for k in range(len(grid_search['generator'])):
|
180 |
-
for i in range(len(grid_search['strength'])):
|
181 |
-
for j in range(len(grid_search['guidance_scale'])):
|
182 |
-
options.append({'strength': grid_search['strength'][i],
|
183 |
-
'guidance_scale': grid_search['guidance_scale'][j],
|
184 |
-
'generator': grid_search['generator'][k],
|
185 |
-
**params
|
186 |
-
})
|
187 |
-
return options
|
188 |
-
|
189 |
-
|
190 |
-
def make_captions(options: List[Dict]) -> List[str]:
|
191 |
-
"""Method to make captions
|
192 |
-
Args:
|
193 |
-
options (List[Dict]): grid parameters
|
194 |
-
Returns:
|
195 |
-
List[str]: captions
|
196 |
-
"""
|
197 |
-
captions = []
|
198 |
-
for option in options:
|
199 |
-
captions.append(
|
200 |
-
f"strength {option['strength']}, guidance {option['guidance_scale']}, steps {option['num_inference_steps']}")
|
201 |
-
return captions
|
202 |
-
|
203 |
-
|
204 |
@torch.inference_mode()
|
205 |
def make_image_controlnet(image: np.ndarray,
|
206 |
mask_image: np.ndarray,
|
@@ -219,49 +197,30 @@ def make_image_controlnet(image: np.ndarray,
|
|
219 |
List[Image.Image]: list of generated images
|
220 |
"""
|
221 |
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
torch.cuda.empty_cache()
|
226 |
-
images = []
|
227 |
-
|
228 |
-
common_parameters = {'prompt': positive_prompt,
|
229 |
-
'negative_prompt': negative_prompt,
|
230 |
-
'num_inference_steps': 30,
|
231 |
-
'controlnet_conditioning_scale': 1.1,
|
232 |
-
'controlnet_conditioning_scale_decay': 0.96,
|
233 |
-
'controlnet_steps': 28,
|
234 |
-
}
|
235 |
-
|
236 |
-
grid_search = {'strength': [1.00, ],
|
237 |
-
'guidance_scale': [7.0],
|
238 |
-
'generator': [[torch.Generator(device="cuda").manual_seed(seed+i)] for i in range(1)],
|
239 |
-
}
|
240 |
-
|
241 |
-
prompt_settings = make_grid_parameters(grid_search, common_parameters)
|
242 |
-
|
243 |
|
244 |
-
mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
|
245 |
image = Image.fromarray(image).convert("RGB")
|
246 |
controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert("RGB").filter(ImageFilter.GaussianBlur(radius = 9))
|
247 |
-
|
248 |
mask_image_postproc = convolution(mask_image)
|
249 |
|
250 |
-
with catchtime("Controlnet generation total"):
|
251 |
-
for _, setting in enumerate(prompt_settings):
|
252 |
-
st.success(f"{pipe.queue_size} images in the queue, can take up to {(pipe.queue_size)+1 * 20} seconds")
|
253 |
-
with catchtime("Controlnet generation"):
|
254 |
-
generated_image = pipe(
|
255 |
-
**setting,
|
256 |
-
image=image,
|
257 |
-
mask_image=mask_image,
|
258 |
-
controlnet_conditioning_image=controlnet_conditioning_image,
|
259 |
-
).images[0]
|
260 |
-
generated_image = postprocess_image_masking(
|
261 |
-
generated_image, image, mask_image_postproc)
|
262 |
-
images.append(generated_image)
|
263 |
|
264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
|
266 |
|
267 |
@torch.inference_mode()
|
@@ -278,27 +237,19 @@ def make_inpainting(positive_prompt: str,
|
|
278 |
Returns:
|
279 |
List[Image.Image]: list of generated images
|
280 |
"""
|
|
|
281 |
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
with catchtime("Inpainting generation"):
|
294 |
-
image_ = pipe(image=image,
|
295 |
-
mask_image=Image.fromarray((mask_image * 255).astype(np.uint8)),
|
296 |
-
height=HEIGHT,
|
297 |
-
width=WIDTH,
|
298 |
-
**common_parameters
|
299 |
-
).images[0]
|
300 |
-
images.append(image_)
|
301 |
-
return images
|
302 |
|
303 |
|
304 |
@torch.inference_mode()
|
@@ -316,9 +267,8 @@ def segment_image(image: Image) -> Image:
|
|
316 |
outputs = image_segmentor(pixel_values)
|
317 |
|
318 |
seg = image_processor.post_process_semantic_segmentation(
|
319 |
-
outputs, target_sizes=[image.size[::-1]])
|
320 |
-
|
321 |
-
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
|
322 |
palette = np.array(ade_palette())
|
323 |
for label, color in enumerate(palette):
|
324 |
color_seg[seg == label, :] = color
|
|
|
29 |
|
30 |
class ControlNetPipeline:
|
31 |
def __init__(self):
|
|
|
32 |
self.in_use = False
|
33 |
self.controlnet = ControlNetModel.from_pretrained(
|
34 |
"BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
|
|
|
42 |
|
43 |
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
44 |
self.pipe.enable_xformers_memory_efficient_attention()
|
|
|
45 |
self.pipe = self.pipe.to("cuda")
|
46 |
|
47 |
self.waiting_queue = []
|
|
|
70 |
self.waiting_queue.pop(0)
|
71 |
flush()
|
72 |
return results
|
73 |
+
|
74 |
+
class SDPipeline:
    """Run a Stable Diffusion 2 inpainting pipeline one caller at a time.

    The wrapped ``StableDiffusionInpaintPipeline`` is loaded from
    "stabilityai/stable-diffusion-2-inpainting" in float16 and moved to CUDA.
    Calls are let through in arrival order using a list of ticket numbers.
    """

    def __init__(self):
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
            safety_checker=None,
        )

        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")

        # Tickets of callers that are waiting or running; the ticket at
        # index 0 is the one allowed to run.
        self.waiting_queue = []
        # Last ticket number handed out.
        self.count = 0

    @property
    def queue_size(self):
        """Return the number of tickets currently in the waiting queue."""
        return len(self.waiting_queue)

    def __call__(self, **kwargs):
        """Wait for this call's turn, run the pipeline and return its result.

        Args:
            **kwargs: forwarded unchanged to the wrapped pipeline.

        Returns:
            Whatever the wrapped pipeline returns for ``kwargs``.
        """
        # NOTE(review): taking a ticket is not protected by a lock; if calls
        # can arrive from several threads at once, confirm this is acceptable.
        self.count += 1
        number = self.count
        self.waiting_queue.append(number)

        # Poll until this call's ticket reaches the head of the queue.
        while self.waiting_queue[0] != number:
            print(f"Wait for your turn {number} in queue {self.waiting_queue}")
            time.sleep(0.5)

        # Report this caller's ticket; self.count may already have been
        # advanced by callers that queued up while this one was waiting.
        print("It's the turn of", number)
        try:
            return self.pipe(**kwargs)
        finally:
            # Always release the head ticket, even when the pipeline raises;
            # otherwise every later caller would wait forever.
            self.waiting_queue.pop(0)
            flush()
|
111 |
|
112 |
|
113 |
def convolution(mask: Image.Image, size=9) -> Image:
|
|
|
175 |
Returns:
|
176 |
StableDiffusionInpaintPipeline: inpainting pipeline
|
177 |
"""
|
178 |
+
pipe = SDPipeline()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
return pipe
|
180 |
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
@torch.inference_mode()
|
183 |
def make_image_controlnet(image: np.ndarray,
|
184 |
mask_image: np.ndarray,
|
|
|
197 |
List[Image.Image]: list of generated images
|
198 |
"""
|
199 |
|
200 |
+
pipe = get_controlnet()
|
201 |
+
flush()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
|
|
203 |
image = Image.fromarray(image).convert("RGB")
|
204 |
controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert("RGB").filter(ImageFilter.GaussianBlur(radius = 9))
|
205 |
+
mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
|
206 |
mask_image_postproc = convolution(mask_image)
|
207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
+
st.success(f"{pipe.queue_size} images in the queue, can take up to {(pipe.queue_size)+1 * 10} seconds")
|
210 |
+
generated_image = pipe(
|
211 |
+
prompt=positive_prompt,
|
212 |
+
negative_prompt=negative_prompt,
|
213 |
+
num_inference_steps=20,
|
214 |
+
strength=[1.00, ],
|
215 |
+
guidance_scale=[7.0],
|
216 |
+
generator=[torch.Generator(device="cuda").manual_seed(seed)],
|
217 |
+
image=image,
|
218 |
+
mask_image=mask_image,
|
219 |
+
controlnet_conditioning_image=controlnet_conditioning_image,
|
220 |
+
).images[0]
|
221 |
+
generated_image = postprocess_image_masking(generated_image, image, mask_image_postproc)
|
222 |
+
|
223 |
+
return generated_image
|
224 |
|
225 |
|
226 |
@torch.inference_mode()
|
|
|
237 |
Returns:
|
238 |
List[Image.Image]: list of generated images
|
239 |
"""
|
240 |
+
pipe = get_inpainting_pipeline()
|
241 |
|
242 |
+
flush()
|
243 |
+
image_ = pipe(image=image,
|
244 |
+
mask_image=Image.fromarray((mask_image * 255).astype(np.uint8)),
|
245 |
+
prompt=positive_prompt,
|
246 |
+
negative_prompt=negative_prompt,
|
247 |
+
num_inference_steps=20,
|
248 |
+
height=HEIGHT,
|
249 |
+
width=WIDTH,
|
250 |
+
**common_parameters
|
251 |
+
).images[0]
|
252 |
+
return image_
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
|
255 |
@torch.inference_mode()
|
|
|
267 |
outputs = image_segmentor(pixel_values)
|
268 |
|
269 |
seg = image_processor.post_process_semantic_segmentation(
|
270 |
+
outputs, target_sizes=[image.size[::-1]])[0]
|
271 |
+
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
|
|
272 |
palette = np.array(ade_palette())
|
273 |
for label, color in enumerate(palette):
|
274 |
color_seg[seg == label, :] = color
|