vumichien commited on
Commit
20d2ab6
·
1 Parent(s): 5a42864

Update annotator/hed/__init__.py

Browse files
Files changed (1) hide show
  1. annotator/hed/__init__.py +9 -2
annotator/hed/__init__.py CHANGED
@@ -100,13 +100,20 @@ class HEDdetector:
100
  if not os.path.exists(modelpath):
101
  from basicsr.utils.download_util import load_file_from_url
102
  load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
103
- self.netNetwork = Network(modelpath).cuda().eval()
 
 
 
 
104
 
105
  def __call__(self, input_image):
106
  assert input_image.ndim == 3
107
  input_image = input_image[:, :, ::-1].copy()
108
  with torch.no_grad():
109
- image_hed = torch.from_numpy(input_image).float().cuda()
 
 
 
110
  image_hed = image_hed / 255.0
111
  image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
112
  edge = self.netNetwork(image_hed)[0]
 
100
  if not os.path.exists(modelpath):
101
  from basicsr.utils.download_util import load_file_from_url
102
  load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
103
+ if torch.cuda.is_available():
104
+ self.netNetwork = Network(modelpath).cuda().eval()
105
+ else:
106
+ self.netNetwork = Network(modelpath).eval()
107
+
108
 
109
  def __call__(self, input_image):
110
  assert input_image.ndim == 3
111
  input_image = input_image[:, :, ::-1].copy()
112
  with torch.no_grad():
113
+ if torch.cuda.is_available():
114
+ image_hed = torch.from_numpy(input_image).float().cuda()
115
+ else:
116
+ image_hed = torch.from_numpy(input_image).float()
117
  image_hed = image_hed / 255.0
118
  image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
119
  edge = self.netNetwork(image_hed)[0]