nowsyn commited on
Commit
caec4eb
·
1 Parent(s): 5d9aca7

update for GPU usage

Browse files
Files changed (1) hide show
  1. annotator/hed/__init__.py +3 -2
annotator/hed/__init__.py CHANGED
@@ -87,10 +87,11 @@ class SOFT_HEDdetector:
87
  if not os.path.exists(modelpath):
88
  from basicsr.utils.download_util import load_file_from_url
89
  load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
90
- self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
91
- self.netNetwork.load_state_dict(torch.load(modelpath))
92
 
93
  def __call__(self, input_image, safe=False, threshold=200):
 
94
  assert input_image.ndim == 3
95
  H, W, C = input_image.shape
96
  with torch.no_grad():
 
87
  if not os.path.exists(modelpath):
88
  from basicsr.utils.download_util import load_file_from_url
89
  load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
90
+ self.netNetwork = ControlNetHED_Apache2().float().eval()
91
+ self.netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu'))
92
 
93
  def __call__(self, input_image, safe=False, threshold=200):
94
+ self.netNetwork.cuda()
95
  assert input_image.ndim == 3
96
  H, W, C = input_image.shape
97
  with torch.no_grad():