3v324v23 committed on
Commit 127e696 · 1 Parent(s): ee8a48b

controlnet now kicks out models to save memory
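Concretely, each preprocessing branch now imports its annotator's unload helper alongside the apply function and calls it once the conditioning tensor has been built, so the annotator weights (MiDaS, HED, MLSD, PiDiNet, OpenPose) no longer stay resident on the GPU. A rough sketch of the pattern for the depth branch follows; the standalone wrapper function below is illustrative only, the real code sits inside ControlNet's preprocessing as shown in the diff.

import numpy as np
import torch
import torchvision.transforms as tvtrans

def preprocess_depth(x_list, device='cuda'):
    # Sketch only: 'preprocess_depth' is a hypothetical wrapper around the
    # 'depth' branch of ControlNet's preprocessing (see the diff below).
    # Lazy import: the annotator is only loaded when this branch is hit.
    from .controlnet_annotator.midas import apply_midas, unload_midas_model
    y_list, _ = zip(*[apply_midas(input_image=np.array(xi), a=np.pi * 2.0, device=device)
                      for xi in x_list])
    y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
    y_torch = y_torch.repeat(1, 3, 1, 1)               # single-channel depth -> 3-channel
    y_torch = y_torch.to(device).to(torch.float32)
    # unload_* moves the annotator's module-level model back to the CPU
    # (see the commented-out MiDaS helpers removed further down).
    unload_midas_model()
    return y_torch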

Files changed (1)
  1. lib/model_zoo/controlnet.py +27 -65
lib/model_zoo/controlnet.py CHANGED
@@ -14,54 +14,6 @@ from .openaimodel import \
     ResBlock, AttentionBlock, SpatialTransformer, \
     Downsample, timestep_embedding
 
-####################
-# preprocess depth #
-####################
-
-# depth_model = None
-
-# def unload_midas_model():
-#     global depth_model
-#     if depth_model is not None:
-#         depth_model = depth_model.cpu()
-
-# def apply_midas(input_image, a=np.pi*2.0, bg_th=0.1, device='cpu'):
-#     import cv2
-#     from einops import rearrange
-#     from .controlnet_annotators.midas import MiDaSInference
-#     global depth_model
-#     if depth_model is None:
-#         depth_model = MiDaSInference(model_type="dpt_hybrid")
-#     depth_model = depth_model.to(device)
-
-#     assert input_image.ndim == 3
-#     image_depth = input_image
-#     with torch.no_grad():
-#         image_depth = torch.from_numpy(image_depth).float()
-#         image_depth = image_depth.to(device)
-#         image_depth = image_depth / 127.5 - 1.0
-#         image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
-#         depth = depth_model(image_depth)[0]
-
-#         depth_pt = depth.clone()
-#         depth_pt -= torch.min(depth_pt)
-#         depth_pt /= torch.max(depth_pt)
-#         depth_pt = depth_pt.cpu().numpy()
-#         depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
-
-#         depth_np = depth.cpu().numpy()
-#         x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
-#         y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
-#         z = np.ones_like(x) * a
-#         x[depth_pt < bg_th] = 0
-#         y[depth_pt < bg_th] = 0
-#         normal = np.stack([x, y, z], axis=2)
-#         normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
-#         normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
-
-#     return depth_image, normal_image
-
-
 @register('controlnet')
 class ControlNet(nn.Module):
     def __init__(
@@ -360,37 +312,41 @@ class ControlNet(nn.Module):
             return y_torch
 
         elif type == 'depth':
-            from .controlnet_annotator.midas import apply_midas
+            from .controlnet_annotator.midas import apply_midas, unload_midas_model
             y_list, _ = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, device=device) for xi in x_list])
             y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
             y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
             y_torch = y_torch.to(device).to(torch.float32)
+            unload_midas_model()
             return y_torch
 
         elif type in ['hed', 'softedge_v11p']:
-            from .controlnet_annotator.hed import apply_hed
+            from .controlnet_annotator.hed import apply_hed, unload_hed_model
             y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
             y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
             y_torch = y_torch.to(device).to(torch.float32)
+            from .controlnet_annotator.midas import model as model_midas
+            unload_hed_model()
             return y_torch
 
        elif type in ['mlsd', 'mlsd_v11p']:
             thr_v = kwargs.pop('thr_v', 0.1)
             thr_d = kwargs.pop('thr_d', 0.1)
-            from .controlnet_annotator.mlsd import apply_mlsd
+            from .controlnet_annotator.mlsd import apply_mlsd, unload_mlsd_model
             y_list = [apply_mlsd(np.array(xi), thr_v=thr_v, thr_d=thr_d, device=device) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
             y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
             y_torch = y_torch.to(device).to(torch.float32)
+            unload_mlsd_model()
             return y_torch
 
         elif type == 'normal':
             bg_th = kwargs.pop('bg_th', 0.4)
-            from .controlnet_annotator.midas import apply_midas
+            from .controlnet_annotator.midas import apply_midas, unload_midas_model
             _, y_list = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, bg_th=bg_th, device=device) for xi in x_list])
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
-            y_torch = y_torch.to(device).to(torch.float32)
+            unload_midas_model()
             return y_torch
 
         elif type in ['openpose', 'openpose_v11p']:
@@ -403,6 +359,7 @@ class ControlNet(nn.Module):
             y_list = [apply_openpose(np.array(xi)) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
             y_torch = y_torch.to(device).to(torch.float32)
+            OpenposeModel.unload()
             return y_torch
 
         elif type in ['openpose_withface', 'openpose_withface_v11p']:
@@ -415,6 +372,7 @@ class ControlNet(nn.Module):
             y_list = [apply_openpose(np.array(xi)) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
             y_torch = y_torch.to(device).to(torch.float32)
+            OpenposeModel.unload()
             return y_torch
 
         elif type in ['openpose_withfacehand', 'openpose_withfacehand_v11p']:
@@ -427,6 +385,7 @@ class ControlNet(nn.Module):
             y_list = [apply_openpose(np.array(xi)) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
             y_torch = y_torch.to(device).to(torch.float32)
+            OpenposeModel.unload()
             return y_torch
 
         elif type == 'scribble':
@@ -454,21 +413,23 @@ class ControlNet(nn.Module):
                 return result
 
             if method == 'hed':
-                from .controlnet_annotator.hed import apply_hed
+                from .controlnet_annotator.hed import apply_hed, unload_hed_model
                 y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
                 y_list = [make_scribble(yi) for yi in y_list]
                 y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
                 y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
                 y_torch = y_torch.to(device).to(torch.float32)
+                unload_hed_model()
                 return y_torch
 
            elif method == 'pidinet':
-                from .controlnet_annotator.pidinet import apply_pidinet
+                from .controlnet_annotator.pidinet import apply_pidinet, unload_pid_model
                 y_list = [apply_pidinet(np.array(xi), device=device) for xi in x_list]
                 y_list = [make_scribble(yi) for yi in y_list]
                 y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
                 y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
                 y_torch = y_torch.to(device).to(torch.float32)
+                unload_pid_model()
                 return y_torch
 
             elif method == 'xdog':
@@ -491,13 +452,14 @@ class ControlNet(nn.Module):
                 raise ValueError
 
         elif type == 'seg':
-            method = kwargs.pop('method', 'ufade20k')
-            if method == 'ufade20k':
-                from .controlnet_annotator.uniformer import apply_uniformer
-                y_list = [apply_uniformer(np.array(xi), palette='ade20k', device=device) for xi in x_list]
-                y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
-                y_torch = y_torch.to(device).to(torch.float32)
-                return y_torch
-
-            else:
-                raise ValueError
+            assert False, "This part is broken"
+            # method = kwargs.pop('method', 'ufade20k')
+            # if method == 'ufade20k':
+            #     from .controlnet_annotator.uniformer import apply_uniformer
+            #     y_list = [apply_uniformer(np.array(xi), palette='ade20k', device=device) for xi in x_list]
+            #     y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+            #     y_torch = y_torch.to(device).to(torch.float32)
+            #     return y_torch
+
+            # else:
+            #     raise ValueError