Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,387 Bytes
084ab29 849873b f001b17 849873b 084ab29 849873b 084ab29 849873b 084ab29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import tempfile
import numpy as np
import torch
import trimesh
import spaces
# Use the GPU when one is visible; the model and the inference inputs are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# use torch hub
# zeroGPU hack from https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/9
# Globally replaces torch.jit.script with an identity function, so any scripting
# done while the hub model below is built returns the plain Python object instead.
# This assignment must stay BEFORE torch.hub.load for it to take effect there.
# NOTE(review): the patch also affects every later torch.jit.script call in this
# process; presumably that is acceptable for this app — confirm.
torch.jit.script = lambda f: f
# Load the pretrained ZoeDepth "ZoeD_NK" model from the isl-org/ZoeDepth hub repo
# once at import time, move it to `device`, and switch it to eval mode.
model = torch.hub.load("isl-org/ZoeDepth", "ZoeD_NK", pretrained=True).to(device).eval()
def get_intrinsics(H, W, fov=55.):
    """Return the 3x3 pinhole-camera intrinsics matrix for an H x W image.

    The principal point is placed at the image centre, and the focal length
    (shared by both axes) is derived from ``fov``, interpreted as the
    horizontal field of view in degrees.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        fov: horizontal field of view in degrees.

    Returns:
        numpy array ``[[f, 0, cx], [0, f, cy], [0, 0, 1]]`` with
        ``cx = W / 2`` and ``cy = H / 2``.
    """
    half_angle = 0.5 * fov * np.pi / 180.0
    focal = 0.5 * W / np.tan(half_angle)

    intrinsics = np.eye(3)
    intrinsics[0, 0] = focal
    intrinsics[1, 1] = focal
    intrinsics[0, 2] = 0.5 * W  # cx
    intrinsics[1, 2] = 0.5 * H  # cy
    return intrinsics
def depth_to_points(depth, R=None, t=None, fov=55.):
    """Back-project a depth map into a per-pixel 3D point map.

    Each pixel (u, v) is lifted to ``depth * K^-1 @ [u, v, 1]`` using the
    pinhole intrinsics from ``get_intrinsics``. Its x and y coordinates are
    then negated (the ``M`` matrix below), and finally ``R @ p + t`` is applied.

    Args:
        depth: array of shape (B, H, W). Only ``depth[0]`` is converted; any
            further batch entries are ignored.
        R: optional 3x3 matrix (e.g. a rotation) applied to every point;
            defaults to the identity. Array-likes are accepted.
        t: optional translation with 3 elements; defaults to zeros.
            Array-likes are accepted.
        fov: horizontal field of view in degrees, forwarded to
            ``get_intrinsics``.

    Returns:
        (H, W, 3) float array of 3D points for the first batch element.
    """
    K = get_intrinsics(depth.shape[1], depth.shape[2], fov=fov)
    Kinv = np.linalg.inv(K)
    R = np.eye(3) if R is None else np.asarray(R)
    t = np.zeros(3) if t is None else np.asarray(t).reshape(3)
    # M converts from your coordinate system to PyTorch3D's (it negates x and y).
    M = np.diag([-1.0, -1.0, 1.0])

    height, width = depth.shape[1:3]
    x, y = np.meshgrid(np.arange(width), np.arange(height))
    # Homogeneous pixel coordinates [u, v, 1], shape (H, W, 3).
    coord = np.stack((x, y, np.ones_like(x)), axis=-1).astype(np.float32)

    # Row-vector form: coord @ Kinv.T equals (Kinv @ [u, v, 1]^T)^T for every
    # pixel. Applying Kinv once and then scaling by depth avoids materialising
    # a per-pixel (H, W, 3, 3) copy of Kinv.
    rays = coord @ Kinv.T
    pts = depth[0][..., None] * rays
    # Convert to PyTorch3D's convention, then move to the target viewpoint.
    pts = pts @ M.T
    pts = pts @ R.T + t
    return pts
def create_triangles(h, w, mask=None):
    """Build triangle vertex indices for a regular h x w pixel grid.

    Adapted from
    https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68

    Vertices are numbered row-major (index = row * w + col). Every grid cell
    spanned by four neighbouring vertices is split into two triangles:
    (top-left, bottom-left, top-right) and (bottom-right, top-right,
    bottom-left). The indices depend only on the grid size, so nothing here
    needs to be differentiable.

    Args:
        h: (int) grid height in pixels.
        w: (int) grid width in pixels.
        mask: optional boolean array with h * w elements that flattens
            row-major into the vertex order. When given, only triangles whose
            three vertices are all True are kept.

    Returns:
        triangles: 2D numpy array of vertex indices with shape
        (2(w-1)(h-1), 3), or fewer rows when ``mask`` is given.
    """
    cols, rows = np.meshgrid(range(w - 1), range(h - 1))
    top_left = rows * w + cols
    top_right = top_left + 1
    bottom_left = top_left + w
    bottom_right = bottom_left + 1

    # Six corner indices per cell on the last axis = two triangles per cell.
    corners = np.stack(
        [top_left, bottom_left, top_right, bottom_right, top_right, bottom_left],
        axis=-1,
    )
    triangles = corners.reshape(((w - 1) * (h - 1) * 2, 3))

    if mask is None:
        return triangles
    keep = mask.reshape(-1)[triangles].all(1)
    return triangles[keep]
def depth_edges_mask(depth, threshold=0.05):
    """Return a mask of pixels lying on depth discontinuities.

    A pixel is flagged when the magnitude of the depth gradient exceeds
    ``threshold``. ``np.gradient`` computes the gradient with central
    differences in the interior and one-sided differences at the borders.

    Args:
        depth: 2D numpy array of shape (H, W), e.g. float32. Each dimension
            needs at least 2 elements for ``np.gradient``.
        threshold: gradient-magnitude cutoff, in depth units per pixel.
            Defaults to 0.05, the previously hard-coded value.

    Returns:
        mask: 2D numpy array of shape (H, W) with dtype bool.
    """
    # np.gradient returns one derivative per axis in axis order:
    # rows (y) first, then columns (x).
    depth_dy, depth_dx = np.gradient(depth)
    depth_grad = np.hypot(depth_dx, depth_dy)
    return depth_grad > threshold
@spaces.GPU
def mesh_reconstruction(
    masked_image: np.ndarray,
    mask: np.ndarray,
    remove_edges: bool = True,
    fov: float = 55.,
    mask_threshold: float = 25.,
):
    """Reconstruct a vertex-coloured triangle mesh from one image with ZoeDepth.

    Every pixel becomes one vertex, back-projected with the predicted depth and
    coloured with its RGB value. Neighbouring pixels are joined into triangles
    wherever the (thresholded) mask is set.

    Args:
        masked_image: (H, W, C) image with C >= 3. Only the first three
            channels are used and they are divided by 255 (assumes 0-255
            values, e.g. uint8 — TODO confirm with caller).
        mask: (H, W) or (H, W, C) array; for the latter only channel 0 is
            used. Pixels whose value exceeds ``mask_threshold`` are kept.
            Presumably has the same H, W as ``masked_image`` — TODO confirm.
        remove_edges: also drop pixels flagged by ``depth_edges_mask``, so no
            triangles are built across depth jumps.
        fov: horizontal field of view in degrees used for back-projection.
        mask_threshold: cutoff applied to the mask values.

    Returns:
        Path of a temporary ``.glb`` file. It is created with
        ``delete=False``, so it is not removed automatically.
    """
    # (3, H, W) float image scaled by 1/255.
    rgb = masked_image[..., :3].transpose(2, 0, 1) / 255.
    sample = torch.from_numpy(rgb).to(device).unsqueeze(0).float()
    with torch.no_grad():
        depth = model.infer(sample)
    # NOTE(review): assumes the squeezed prediction is (H, W), matching the
    # input resolution — verify against ZoeDepth's infer().
    depth = depth.squeeze().cpu().numpy()

    # One vertex per pixel, row-major, matching create_triangles' indexing.
    verts = depth_to_points(depth[None], fov=fov).reshape(-1, 3)

    if mask.ndim > 2:
        mask = mask[..., 0]
    mask = mask > mask_threshold
    if remove_edges:
        mask = np.logical_and(mask, ~depth_edges_mask(depth))

    height, width = rgb.shape[1:]
    triangles = create_triangles(height, width, mask=mask)
    colors = rgb.transpose(1, 2, 0).reshape(-1, 3)
    mesh = trimesh.Trimesh(vertices=verts, faces=triangles, vertex_colors=colors)

    # Save as glb tmp file (obj will look inverted in ui). Close the handle
    # before exporting: this avoids leaking the descriptor and lets the path
    # be reopened on platforms that forbid opening an already-open temp file.
    with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as mesh_file:
        mesh_file_path = mesh_file.name
    mesh.export(mesh_file_path)
    return mesh_file_path
|