Spaces:
Runtime error
Runtime error
update for GPU usage
Browse files
annotator/hed/__init__.py
CHANGED
@@ -87,10 +87,11 @@ class SOFT_HEDdetector:
|
|
87 |
if not os.path.exists(modelpath):
|
88 |
from basicsr.utils.download_util import load_file_from_url
|
89 |
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
90 |
-
self.netNetwork = ControlNetHED_Apache2().float().
|
91 |
-
self.netNetwork.load_state_dict(torch.load(modelpath))
|
92 |
|
93 |
def __call__(self, input_image, safe=False, threshold=200):
|
|
|
94 |
assert input_image.ndim == 3
|
95 |
H, W, C = input_image.shape
|
96 |
with torch.no_grad():
|
|
|
87 |
if not os.path.exists(modelpath):
|
88 |
from basicsr.utils.download_util import load_file_from_url
|
89 |
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
90 |
+
self.netNetwork = ControlNetHED_Apache2().float().eval()
|
91 |
+
self.netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu'))
|
92 |
|
93 |
def __call__(self, input_image, safe=False, threshold=200):
|
94 |
+
self.netNetwork.cuda()
|
95 |
assert input_image.ndim == 3
|
96 |
H, W, C = input_image.shape
|
97 |
with torch.no_grad():
|