Spaces:
Paused
Paused
Upload folder using huggingface_hub
Browse files- heatmap.py +4 -1
heatmap.py
CHANGED
@@ -9,6 +9,9 @@ model_type = "DPT_Large" # MiDaS v3 - Large (highest accuracy, slowest i
|
|
9 |
|
10 |
midas = torch.hub.load("intel-isl/MiDaS", model_type)
|
11 |
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
|
|
|
|
|
|
|
12 |
|
13 |
if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
|
14 |
transform = midas_transforms.dpt_transform
|
@@ -18,7 +21,7 @@ else:
|
|
18 |
def get_heatmap(src_img):
|
19 |
cv2image = cv2.imread(src_img)
|
20 |
img = cv2.cvtColor(cv2image, cv2.COLOR_BGR2RGB)
|
21 |
-
input_batch = transform(img)
|
22 |
with torch.inference_mode():
|
23 |
prediction = midas(input_batch)
|
24 |
|
|
|
9 |
|
10 |
midas = torch.hub.load("intel-isl/MiDaS", model_type)
|
11 |
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
|
12 |
+
# Move to CUDA if available
|
13 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
+
midas = midas.to(device)
|
15 |
|
16 |
if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
|
17 |
transform = midas_transforms.dpt_transform
|
|
|
21 |
def get_heatmap(src_img):
|
22 |
cv2image = cv2.imread(src_img)
|
23 |
img = cv2.cvtColor(cv2image, cv2.COLOR_BGR2RGB)
|
24 |
+
input_batch = transform(img).to(device)
|
25 |
with torch.inference_mode():
|
26 |
prediction = midas(input_batch)
|
27 |
|