# Hugging Face Space (status: Paused)
# Monocular depth estimation with the intel-isl/MiDaS models from torch.hub.
import cv2
import torch
import torch.nn.functional as F

# Choose one variant; the trade-off is accuracy versus inference speed.
model_type = "DPT_Large"  # MiDaS v3 - Large (highest accuracy, slowest inference speed)
# model_type = "DPT_Hybrid"  # MiDaS v3 - Hybrid (medium accuracy, medium inference speed)
# model_type = "MiDaS_small"  # MiDaS v2.1 - Small (lowest accuracy, highest inference speed)

# Both calls download from torch.hub (network access on first use).
midas = torch.hub.load("intel-isl/MiDaS", model_type)
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

# Move to CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas = midas.to(device)
# The model is only used for inference. Switch layers such as BatchNorm/Dropout
# to evaluation behaviour here: torch.inference_mode() disables autograd but
# does NOT change the module's train/eval mode.
midas.eval()

# The DPT variants and the small model use different input preprocessing.
if model_type in ("DPT_Large", "DPT_Hybrid"):
    transform = midas_transforms.dpt_transform
else:
    transform = midas_transforms.small_transform
def get_heatmap(src_img):
    """Run the module-level MiDaS model on an image file and return its prediction.

    Args:
        src_img: Path of an image file to load with ``cv2.imread``.

    Returns:
        A NumPy array holding the model's prediction, resized (bicubic) to the
        source image's height and width. It is 2-D when ``transform`` yields a
        single-image batch (assumed here; confirm against the MiDaS transforms).

    Raises:
        ValueError: If ``cv2.imread`` cannot load ``src_img`` (missing file,
            unreadable or unsupported image, ...).
    """
    cv2image = cv2.imread(src_img)
    # cv2.imread reports failure by returning None instead of raising. Without
    # this check, the error only appears later as an opaque cvtColor assertion.
    if cv2image is None:
        raise ValueError(f"Could not read image file: {src_img!r}")
    # OpenCV decodes to BGR; convert to RGB before the MiDaS preprocessing.
    img = cv2.cvtColor(cv2image, cv2.COLOR_BGR2RGB)
    input_batch = transform(img).to(device)

    with torch.inference_mode():
        prediction = midas(input_batch)
        # Upsample the prediction back to the original (height, width).
        # unsqueeze(1) adds the channel axis that F.interpolate expects.
        prediction = F.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    return prediction.cpu().numpy()