Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
try: | |
from ...MiDaS.midas.dpt_depth import DPTDepthModel | |
except ImportError: | |
print('Please pull the MiDaS submodule via "git submodule update --init --recursive"!') | |
class MidasDetector(nn.Module):
    """Frozen wrapper around a MiDaS ``DPTDepthModel``.

    The wrapped network is built with the ``vitb_rn50_384`` backbone and
    ``non_negative=True``. At construction time, gradients are disabled on
    all of its parameters and it is switched to eval mode.
    """

    def __init__(self, model_path="./models/dpt_hybrid-midas-501f0c75.pt"):
        """Build the depth network and load weights from ``model_path``.

        Args:
            model_path: Checkpoint path forwarded as ``DPTDepthModel(path=...)``.
                The default file name refers to a DPT-Hybrid MiDaS checkpoint.
        """
        super().__init__()
        # requires_grad_() and eval() each return the module itself, so the
        # freezing steps can be chained directly onto construction.
        self.model = (
            DPTDepthModel(
                path=model_path,
                backbone="vitb_rn50_384",
                non_negative=True,
            )
            .requires_grad_(False)
            .eval()
        )

    def _first_parameter(self):
        # Representative parameter used to report dtype/device; raises
        # StopIteration if the module has no parameters.
        return next(self.parameters())

    def dtype(self):
        """Return the dtype of this module's first parameter."""
        return self._first_parameter().dtype

    def device(self):
        """Return the device of this module's first parameter."""
        return self._first_parameter().device

    def forward(self, images: torch.Tensor):
        """Run the wrapped depth network on a batch of images.

        Args:
            images: Image batch laid out as [b, c, h, w].

        Returns:
            The output of ``self.model(images)`` unchanged. This is presumably
            one depth prediction per image; the exact shape and scale are not
            established here (TODO confirm against ``DPTDepthModel.forward``).
        """
        depth = self.model(images)
        return depth