File size: 1,253 Bytes
0d66750
 
 
 
 
 
 
 
 
 
 
af8e933
 
 
0d66750
 
 
 
 
 
 
 
 
af8e933
0d66750
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# intel-isl/MiDaS
import cv2
import torch
import torch.nn.functional as F

# Choose the MiDaS variant; swap the active line to trade accuracy for speed.
model_type = "DPT_Large"     # MiDaS v3 - Large     (highest accuracy, slowest inference speed)
# model_type = "DPT_Hybrid"   # MiDaS v3 - Hybrid    (medium accuracy, medium inference speed)
# model_type = "MiDaS_small"  # MiDaS v2.1 - Small   (lowest accuracy, highest inference speed)

# Both calls fetch (and cache) the intel-isl/MiDaS repo through torch.hub, so
# importing this module performs network I/O on first use.
midas = torch.hub.load("intel-isl/MiDaS", model_type)
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

# Move to CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas = midas.to(device)
# nn.Module instances start in training mode; switch to eval mode so layers
# such as dropout/normalization use inference behavior (as in the MiDaS
# usage example).
midas.eval()

# DPT variants and the small model use different input preprocessing.
if model_type in ("DPT_Large", "DPT_Hybrid"):
    transform = midas_transforms.dpt_transform
else:
    transform = midas_transforms.small_transform

def get_heatmap(src_img):
    """Run MiDaS on the image file at ``src_img`` and return its prediction.

    The image is read with OpenCV (BGR), converted to RGB, preprocessed with
    the module-level ``transform``, moved to ``device`` and passed through
    ``midas``. The prediction is then resized with bicubic interpolation to
    the source image's height and width.

    Note: despite the name, no colormap is applied. The result is the model's
    raw per-pixel prediction (per the MiDaS docs this is relative inverse
    depth -- confirm before relying on it).

    Args:
        src_img: Path to an image file readable by ``cv2.imread``.

    Returns:
        NumPy array holding the resized prediction, normally 2-D with shape
        ``(height, width)`` of the source image (assumes ``transform`` yields
        a batch of one image -- TODO confirm).

    Raises:
        OSError: If OpenCV cannot read ``src_img`` (missing file, unreadable
            or unsupported format).
    """
    bgr = cv2.imread(src_img)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error shows up later as an opaque cv2.error.
        raise OSError(f"could not read image: {src_img!r}")
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    input_batch = transform(img).to(device)

    with torch.inference_mode():
        prediction = midas(input_batch)
        # unsqueeze(1) adds the channel axis F.interpolate needs for 2-D
        # resizing; size is (height, width) of the original image.
        prediction = F.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        )

    # Drop only the channel and (size-1) batch axes. A bare .squeeze() would
    # also drop height/width when the image is 1 pixel tall or wide.
    prediction = prediction.squeeze(1).squeeze(0)
    return prediction.cpu().numpy()