andro-flock commited on
Commit
af8e933
·
verified ·
1 Parent(s): 6e98be6

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. heatmap.py +4 -1
heatmap.py CHANGED
@@ -9,6 +9,9 @@ model_type = "DPT_Large" # MiDaS v3 - Large (highest accuracy, slowest i
9
 
10
  midas = torch.hub.load("intel-isl/MiDaS", model_type)
11
  midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
 
 
 
12
 
13
  if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
14
  transform = midas_transforms.dpt_transform
@@ -18,7 +21,7 @@ else:
18
  def get_heatmap(src_img):
19
  cv2image = cv2.imread(src_img)
20
  img = cv2.cvtColor(cv2image, cv2.COLOR_BGR2RGB)
21
- input_batch = transform(img)
22
  with torch.inference_mode():
23
  prediction = midas(input_batch)
24
 
 
9
 
10
  midas = torch.hub.load("intel-isl/MiDaS", model_type)
11
  midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
12
+ # Move to CUDA if available
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ midas = midas.to(device)
15
 
16
  if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
17
  transform = midas_transforms.dpt_transform
 
21
  def get_heatmap(src_img):
22
  cv2image = cv2.imread(src_img)
23
  img = cv2.cvtColor(cv2image, cv2.COLOR_BGR2RGB)
24
+ input_batch = transform(img).to(device)
25
  with torch.inference_mode():
26
  prediction = midas(input_batch)
27