controlnet now kicks out models to save memory
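The change is mechanical: each annotator branch in the preprocessing method now imports an `unload_*` helper alongside its `apply_*` function and calls it once the control map has been computed, so the annotator's weights are moved off the GPU between calls. Below is a minimal, self-contained sketch of that lazy-load/unload pattern, modeled on the commented-out MiDaS helper this commit deletes from the top of the file; `apply_annotator`, `unload_annotator_model`, and the `nn.Linear` stand-in are illustrative placeholders, not the repo's API.

import torch
import torch.nn as nn

_model = None  # module-level cache, mirroring the annotators' `depth_model`

def _load_model():
    # Stand-in for a real annotator loader (e.g. MiDaSInference); a tiny
    # module keeps this sketch runnable.
    return nn.Linear(4, 4)

def apply_annotator(x, device='cpu'):
    # Lazily build the model on first use, then move it to the target device.
    global _model
    if _model is None:
        _model = _load_model()
    _model = _model.to(device)
    with torch.no_grad():
        return _model(x.to(device))

def unload_annotator_model():
    # The "kick out" step: moving the weights back to CPU frees their GPU
    # memory while keeping them cached for the next preprocess call.
    global _model
    if _model is not None:
        _model = _model.cpu()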
lib/model_zoo/controlnet.py  CHANGED  (+27 -65)
@@ -14,54 +14,6 @@ from .openaimodel import \
     ResBlock, AttentionBlock, SpatialTransformer, \
     Downsample, timestep_embedding
 
-####################
-# preprocess depth #
-####################
-
-# depth_model = None
-
-# def unload_midas_model():
-#     global depth_model
-#     if depth_model is not None:
-#         depth_model = depth_model.cpu()
-
-# def apply_midas(input_image, a=np.pi*2.0, bg_th=0.1, device='cpu'):
-#     import cv2
-#     from einops import rearrange
-#     from .controlnet_annotators.midas import MiDaSInference
-#     global depth_model
-#     if depth_model is None:
-#         depth_model = MiDaSInference(model_type="dpt_hybrid")
-#     depth_model = depth_model.to(device)
-
-#     assert input_image.ndim == 3
-#     image_depth = input_image
-#     with torch.no_grad():
-#         image_depth = torch.from_numpy(image_depth).float()
-#         image_depth = image_depth.to(device)
-#         image_depth = image_depth / 127.5 - 1.0
-#         image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
-#         depth = depth_model(image_depth)[0]
-
-#     depth_pt = depth.clone()
-#     depth_pt -= torch.min(depth_pt)
-#     depth_pt /= torch.max(depth_pt)
-#     depth_pt = depth_pt.cpu().numpy()
-#     depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
-
-#     depth_np = depth.cpu().numpy()
-#     x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
-#     y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
-#     z = np.ones_like(x) * a
-#     x[depth_pt < bg_th] = 0
-#     y[depth_pt < bg_th] = 0
-#     normal = np.stack([x, y, z], axis=2)
-#     normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
-#     normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
-
-#     return depth_image, normal_image
-
-
 @register('controlnet')
 class ControlNet(nn.Module):
     def __init__(
@@ -360,37 +312,41 @@ class ControlNet(nn.Module):
             return y_torch
 
         elif type == 'depth':
-            from .controlnet_annotator.midas import apply_midas
+            from .controlnet_annotator.midas import apply_midas, unload_midas_model
             y_list, _ = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, device=device) for xi in x_list])
             y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
             y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
             y_torch = y_torch.to(device).to(torch.float32)
+            unload_midas_model()
             return y_torch
 
         elif type in ['hed', 'softedge_v11p']:
-            from .controlnet_annotator.hed import apply_hed
+            from .controlnet_annotator.hed import apply_hed, unload_hed_model
            y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
             y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
             y_torch = y_torch.to(device).to(torch.float32)
+            from .controlnet_annotator.midas import model as model_midas
+            unload_hed_model()
             return y_torch
 
         elif type in ['mlsd', 'mlsd_v11p']:
             thr_v = kwargs.pop('thr_v', 0.1)
             thr_d = kwargs.pop('thr_d', 0.1)
-            from .controlnet_annotator.mlsd import apply_mlsd
+            from .controlnet_annotator.mlsd import apply_mlsd, unload_mlsd_model
             y_list = [apply_mlsd(np.array(xi), thr_v=thr_v, thr_d=thr_d, device=device) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
             y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
             y_torch = y_torch.to(device).to(torch.float32)
+            unload_mlsd_model()
             return y_torch
 
         elif type == 'normal':
             bg_th = kwargs.pop('bg_th', 0.4)
-            from .controlnet_annotator.midas import apply_midas
+            from .controlnet_annotator.midas import apply_midas, unload_midas_model
             _, y_list = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, bg_th=bg_th, device=device) for xi in x_list])
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
-
+            unload_midas_model()
             return y_torch
 
         elif type in ['openpose', 'openpose_v11p']:
@@ -403,6 +359,7 @@ class ControlNet(nn.Module):
             y_list = [apply_openpose(np.array(xi)) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
             y_torch = y_torch.to(device).to(torch.float32)
+            OpenposeModel.unload()
             return y_torch
 
         elif type in ['openpose_withface', 'openpose_withface_v11p']:
@@ -415,6 +372,7 @@ class ControlNet(nn.Module):
             y_list = [apply_openpose(np.array(xi)) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
             y_torch = y_torch.to(device).to(torch.float32)
+            OpenposeModel.unload()
             return y_torch
 
         elif type in ['openpose_withfacehand', 'openpose_withfacehand_v11p']:
@@ -427,6 +385,7 @@ class ControlNet(nn.Module):
             y_list = [apply_openpose(np.array(xi)) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
             y_torch = y_torch.to(device).to(torch.float32)
+            OpenposeModel.unload()
             return y_torch
 
         elif type == 'scribble':
@@ -454,21 +413,23 @@ class ControlNet(nn.Module):
                 return result
 
             if method == 'hed':
-                from .controlnet_annotator.hed import apply_hed
+                from .controlnet_annotator.hed import apply_hed, unload_hed_model
                 y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
                 y_list = [make_scribble(yi) for yi in y_list]
                 y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
                 y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
                 y_torch = y_torch.to(device).to(torch.float32)
+                unload_hed_model()
                 return y_torch
 
             elif method == 'pidinet':
-                from .controlnet_annotator.pidinet import apply_pidinet
+                from .controlnet_annotator.pidinet import apply_pidinet, unload_pid_model
                 y_list = [apply_pidinet(np.array(xi), device=device) for xi in x_list]
                 y_list = [make_scribble(yi) for yi in y_list]
                 y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
                 y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
                 y_torch = y_torch.to(device).to(torch.float32)
+                unload_pid_model()
                 return y_torch
 
             elif method == 'xdog':
@@ -491,13 +452,14 @@ class ControlNet(nn.Module):
                 raise ValueError
 
         elif type == 'seg':
-            method = kwargs.pop('method', 'ufade20k')
-            if method == 'ufade20k':
-                from .controlnet_annotator.uniformer import apply_uniformer
-                y_list = [apply_uniformer(np.array(xi), palette='ade20k', device=device) for xi in x_list]
-                y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
-                y_torch = y_torch.to(device).to(torch.float32)
-                return y_torch
-
-            else:
-                raise ValueError
+            assert False, "This part is broken"
+            # method = kwargs.pop('method', 'ufade20k')
+            # if method == 'ufade20k':
+            #     from .controlnet_annotator.uniformer import apply_uniformer
+            #     y_list = [apply_uniformer(np.array(xi), palette='ade20k', device=device) for xi in x_list]
+            #     y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+            #     y_torch = y_torch.to(device).to(torch.float32)
+            #     return y_torch
+
+            # else:
+            #     raise ValueError
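Two unload idioms appear in the diff: module-level helpers (`unload_midas_model`, `unload_hed_model`, `unload_mlsd_model`, `unload_pid_model`) and the class-level `OpenposeModel.unload()` used by the openpose branches. A sketch of the class-level form, holding the cache as a class attribute; this is a hypothetical illustration, since the diff shows only the `.unload()` call and not the real `OpenposeModel` body.

class AnnotatorWithUnload:
    # Hypothetical stand-in for an OpenposeModel-style annotator.
    model = None  # cached weights, shared by all callers

    @classmethod
    def unload(cls):
        # Same "kick out" semantics as the module-level helpers:
        # push the cached weights back to CPU to free GPU memory.
        if cls.model is not None:
            cls.model = cls.model.cpu()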