Spaces:
Sleeping
Sleeping
File size: 828 Bytes
d16b52d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
import torch.nn as nn
try:
from ...MiDaS.midas.dpt_depth import DPTDepthModel
except ImportError:
print('Please pull the MiDaS submodule via "git submodule update --init --recursive"!')
class MidasDetector(nn.Module):
    """Inference-only wrapper around a MiDaS ``DPTDepthModel``.

    At construction the network is built with the ``vitb_rn50_384`` backbone
    and ``non_negative=True``, loaded from ``model_path``, has gradients
    disabled on every parameter, and is switched to eval mode.

    NOTE(review): nothing here stops a later ``.train()`` call on this module
    (or a parent module) from switching the wrapped network back to training
    mode.
    """

    def __init__(self, model_path="./models/dpt_hybrid-midas-501f0c75.pt"):
        super().__init__()
        # The default checkpoint name refers to DPT-Hybrid weights, which is
        # presumably why the vitb_rn50_384 backbone is requested; if the
        # checkpoint changes, the backbone likely has to change with it.
        depth_net = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        # Freeze every weight and use eval mode: this wrapper never trains.
        for weight in depth_net.parameters():
            weight.requires_grad = False
        depth_net.train(False)
        self.model = depth_net

    def _first_parameter(self):
        # Representative parameter for dtype/device queries. Assumes all
        # parameters share one dtype/device. Raises StopIteration if the
        # module has no parameters.
        return next(iter(self.parameters()))

    @property
    def dtype(self):
        """Dtype of the module's first parameter."""
        return self._first_parameter().dtype

    @property
    def device(self):
        """Device holding the module's first parameter."""
        return self._first_parameter().device

    def forward(self, images: torch.Tensor):
        """Run the frozen depth network on a batch of images.

        Args:
            images: input batch laid out as ``[b, c, h, w]``.

        Returns:
            Whatever the wrapped ``DPTDepthModel`` returns, unchanged
            (presumably a per-pixel relative depth map; confirm against
            the MiDaS implementation). No autograd graph is recorded.
        """
        with torch.no_grad():
            return self.model(images)
|