JUGGHM commited on
Commit
673c19d
1 Parent(s): 56a2ba1

Update mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py

Browse files
mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py CHANGED
@@ -792,8 +792,8 @@ class RAFTDepthNormalDPT5(nn.Module):
792
  self.relu = nn.ReLU(inplace=True)
793
 
794
  def get_bins(self, bins_num):
795
- #depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda")
796
- depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cpu")
797
  depth_bins_vec = torch.exp(depth_bins_vec)
798
  return depth_bins_vec
799
 
@@ -848,7 +848,8 @@ class RAFTDepthNormalDPT5(nn.Module):
848
  return norm_normalize(torch.cat([normal_out, confidence], dim=1))
849
  #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
850
 
851
- def create_mesh_grid(self, height, width, batch, device="cpu", set_buffer=True):
 
852
  y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
853
  torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
854
  meshgrid = torch.stack((x, y))
 
792
  self.relu = nn.ReLU(inplace=True)
793
 
794
  def get_bins(self, bins_num):
795
+ depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda")
796
+ #depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cpu")
797
  depth_bins_vec = torch.exp(depth_bins_vec)
798
  return depth_bins_vec
799
 
 
848
  return norm_normalize(torch.cat([normal_out, confidence], dim=1))
849
  #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
850
 
851
+ #def create_mesh_grid(self, height, width, batch, device="cpu", set_buffer=True):
852
+ def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
853
  y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
854
  torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
855
  meshgrid = torch.stack((x, y))