Upload 19 files
Browse files
- __init__.py +0 -0
- app.py +49 -0
- checkpoints/epoch=99-step=6000.ckpt +3 -0
- files/bunny_n1_hi_50.obj +0 -0
- files/child_n2_80.obj +0 -0
- files/eight_n3_70.obj +0 -0
- models/SAP/__init__.py +0 -0
- models/SAP/__pycache__/__init__.cpython-39.pyc +0 -0
- models/SAP/__pycache__/dpsr.cpython-39.pyc +0 -0
- models/SAP/__pycache__/model.cpython-39.pyc +0 -0
- models/SAP/__pycache__/utils.cpython-39.pyc +0 -0
- models/SAP/dpsr.py +65 -0
- models/SAP/model.py +129 -0
- models/SAP/utils.py +526 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/model.cpython-39.pyc +0 -0
- models/model.py +181 -0
- requirements.txt +298 -0
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,49 @@
import gradio as gr
import os

import pytorch_lightning as pl
import torch as th
import open3d as o3d
import numpy as np
import trimesh as tm

from models.model import Model

model = Model()
# map_location="cpu" so the demo also runs on CPU-only hosts
ckpt = th.load("./checkpoints/epoch=99-step=6000.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
model.eval()  # disable dropout in the GAT layers during inference


def process_mesh(mesh_file_name):

    mesh = tm.load_mesh(mesh_file_name)

    v = th.tensor(mesh.vertices, dtype=th.float)
    n = th.tensor(mesh.vertex_normals, dtype=th.float)

    with th.no_grad():
        v, f, n, _ = model(v.unsqueeze(0), n.unsqueeze(0))

    mesh = tm.Trimesh(vertices=v.squeeze(0),
                      faces=f.squeeze(0),
                      vertex_normals=n.squeeze(0))
    obj_path = "./sample.obj"
    mesh.export(obj_path)

    return obj_path


demo = gr.Interface(
    fn=process_mesh,
    inputs=gr.Model3D(),
    outputs=gr.Model3D(
        clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
    examples=[
        # os.path.join instead of hard-coded "\\" separators, so the examples
        # also resolve on non-Windows hosts
        [os.path.join(os.path.dirname(__file__), "files", "bunny_n1_hi_50.obj")],
        [os.path.join(os.path.dirname(__file__), "files", "child_n2_80.obj")],
        [os.path.join(os.path.dirname(__file__), "files", "eight_n3_70.obj")],
    ],
)

if __name__ == "__main__":
    demo.launch()
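
Example (not part of the upload): the handler can be smoke-tested from a Python shell before launching the UI, assuming the checkpoint and example meshes above are in place. Importing app builds the model, so the first call takes a while.

from app import process_mesh

out_path = process_mesh("files/bunny_n1_hi_50.obj")  # runs the model once
print(out_path)  # ./sample.obj
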
checkpoints/epoch=99-step=6000.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d3a025dfa88edbf34bf7cf2b69c554cb01ab7d61b4d8cc699a2a6753e14dbdea
size 4308343
files/bunny_n1_hi_50.obj
ADDED
The diff for this file is too large to render.
files/child_n2_80.obj
ADDED
The diff for this file is too large to render.
files/eight_n3_70.obj
ADDED
The diff for this file is too large to render.
models/SAP/__init__.py
ADDED
File without changes
models/SAP/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (133 Bytes).
models/SAP/__pycache__/dpsr.cpython-39.pyc
ADDED
Binary file (2.47 kB).
models/SAP/__pycache__/model.cpython-39.pyc
ADDED
Binary file (4.08 kB).
models/SAP/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (15 kB).
models/SAP/dpsr.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from .utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
import numpy as np
import torch.fft


class DPSR(nn.Module):
    def __init__(self, res, sig=10, scale=True, shift=True):
        """
        :param res: tuple of output field resolution, e.g. (128, 128)
        :param sig: degree of Gaussian smoothing
        """
        super(DPSR, self).__init__()
        self.res = res
        self.sig = sig
        self.dim = len(res)
        self.denom = np.prod(res)
        G = spec_gaussian_filter(res=res, sig=sig).float()
        # self.G.requires_grad = False  # True, if we also make sig a learnable parameter
        self.omega = fftfreqs(res, dtype=torch.float32)
        self.scale = scale
        self.shift = shift
        self.register_buffer("G", G)

    def forward(self, V, N):
        """
        :param V: (batch, nv, 2 or 3) tensor of point cloud coordinates
        :param N: (batch, nv, 2 or 3) tensor of point normals
        :return phi: (batch, res, res, ...) tensor of the output indicator function field
        """
        assert(V.shape == N.shape)  # [b, nv, ndims]
        ras_p = point_rasterize(V, N, self.res)  # [b, n_dim, dim0, dim1, dim2]

        ras_s = torch.fft.rfftn(ras_p, dim=(2, 3, 4))
        ras_s = ras_s.permute(*tuple([0] + list(range(2, self.dim + 1)) + [self.dim + 1, 1]))
        N_ = ras_s[..., None] * self.G  # [b, dim0, dim1, dim2/2+1, n_dim, 1]

        omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1)  # [dim0, dim1, dim2/2+1, n_dim, 1]
        omega *= 2 * np.pi  # normalize frequencies
        omega = omega.to(V.device)

        DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)

        Lap = -torch.sum(omega**2, -2)  # [dim0, dim1, dim2/2+1, 1]
        Phi = DivN / (Lap + 1e-6)  # [b, dim0, dim1, dim2/2+1, 2]
        Phi = Phi.permute(*tuple([list(range(1, self.dim + 2)) + [0]]))  # [dim0, dim1, dim2/2+1, 2, b]
        Phi[tuple([0] * self.dim)] = 0
        Phi = Phi.permute(*tuple([[self.dim + 1] + list(range(self.dim + 1))]))  # [b, dim0, dim1, dim2/2+1, 2]

        phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1, 2, 3))

        if self.shift or self.scale:
            # ensure the field is zero at the input points
            fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1)  # [b, nv]
            if self.shift:  # offset the field to have zero mean at the points
                offset = torch.mean(fv, dim=-1)  # [b,]
                phi -= offset.view(*tuple([-1] + [1] * self.dim))

            phi = phi.permute(*tuple([list(range(1, self.dim + 1)) + [0]]))
            fv0 = phi[tuple([0] * self.dim)]  # [b,]
            phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))

            if self.scale:
                phi = -phi / torch.abs(fv0.view(*tuple([-1] + [1] * self.dim))) * 0.5
        return phi
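
Example (not part of the upload): a minimal DPSR call, assuming an oriented point cloud normalized to the unit cube as the docstring requires.

import torch
from models.SAP.dpsr import DPSR

dpsr = DPSR(res=(128, 128, 128), sig=10)
pts = torch.rand(1, 2048, 3)                 # coordinates in (0, 1)
nrm = torch.randn(1, 2048, 3)
nrm = nrm / nrm.norm(dim=-1, keepdim=True)   # unit normals
phi = dpsr(pts, nrm)                         # [1, 128, 128, 128] indicator field
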
models/SAP/model.py
ADDED
@@ -0,0 +1,129 @@
import torch
import numpy as np
import time
from .utils import point_rasterize, grid_interp, mc_from_psr, calc_inters_points
from .dpsr import DPSR
import torch.nn as nn


class PSR2Mesh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, psr_grid):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
        verts = verts.unsqueeze(0)
        faces = faces.unsqueeze(0)
        normals = normals.unsqueeze(0)

        res = torch.tensor(psr_grid.detach().shape[2])
        ctx.save_for_backward(verts, normals, res)

        return verts, faces, normals

    @staticmethod
    def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        vert_pts, normals, res = ctx.saved_tensors
        res = (res.item(), res.item(), res.item())
        # matrix multiplication between dL/dV and dV/dPSR
        # dV/dPSR = - normals
        grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
        grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res)  # b x 1 x res x res x res

        return grad_grid


class PSR2SurfacePoints(torch.autograd.Function):
    @staticmethod
    def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
        verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
        verts = verts * 2. - 1.  # within the range of [-1, 1]

        p_all, n_all, mask_all = [], [], []

        for i in range(len(poses)):
            pose = poses[i]
            if mask_sample is not None:
                p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i])
            else:
                p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size)

            n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze()
            p_all.append(p_inters)
            n_all.append(n_inters)
            mask_all.append(mask)
        p_inters_all = torch.cat(p_all, dim=0)
        n_inters_all = torch.cat(n_all, dim=0)
        mask_visible = torch.stack(mask_all, dim=0)

        res = torch.tensor(psr_grid.detach().shape[2])
        ctx.save_for_backward(p_inters_all, n_inters_all, res)

        return p_inters_all, mask_visible

    @staticmethod
    def backward(ctx, dL_dp, dL_dmask):
        pts, pts_n, res = ctx.saved_tensors
        res = (res.item(), res.item(), res.item())

        # grad from the p_inters via MLP renderer
        grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
        grad_grid_pts = point_rasterize((pts[None] + 1) / 2, grad_pts.permute(1, 0, 2), res)  # b x 1 x res x res x res

        return grad_grid_pts, None, None, None, None, None


# ResNet blocks from https://github.com/autonomousvision/shape_as_points/blob/12757682f1075d83738b52f96747463b77343caf/src/network/utils.py
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet block class.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension
        size_h (int): hidden dimension
    '''

    def __init__(self, size_in, size_out=None, size_h=None, siren=False):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in

        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out
        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
        # Initialization
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        net = self.fc_0(self.actvn(x))
        dx = self.fc_1(self.actvn(net))

        if self.shortcut is not None:
            x_s = self.shortcut(x)
        else:
            x_s = x

        return x_s + dx
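
Example (not part of the upload): PSR2Mesh is invoked through `.apply`, which is how models/model.py below wires it in; a toy random field has zero crossings everywhere, so it is enough to exercise the interface.

import torch
from models.SAP.model import PSR2Mesh

psr2mesh = PSR2Mesh.apply
psr_grid = torch.randn(1, 64, 64, 64, requires_grad=True)  # toy indicator field
verts, faces, normals = psr2mesh(psr_grid)  # marching cubes, differentiable w.r.t. the grid
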
models/SAP/utils.py
ADDED
@@ -0,0 +1,526 @@
import torch
import io, os, logging, urllib.parse  # urllib.parse imported explicitly; is_url() below needs it
import yaml
import trimesh
import imageio
import numbers
import math
import numpy as np
from collections import OrderedDict
from plyfile import PlyData
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
from skimage import measure, img_as_float32
from igl import adjacency_matrix, connected_components

# NOTE (assumption): mesh_rasterization() and calc_inters_points() below use
# PyTorch3D names (Meshes, PerspectiveCameras, rasterize_meshes) whose imports
# were dropped from this upload and which requirements.txt does not pin.
# Guarded here so the DPSR path keeps working when PyTorch3D is absent.
try:
    from pytorch3d.structures import Meshes
    from pytorch3d.renderer.cameras import PerspectiveCameras
    from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes
except ImportError:
    Meshes = PerspectiveCameras = rasterize_meshes = None

##################################################
# Below are functions for DPSR

def fftfreqs(res, dtype=torch.float32, exact=True):
    """
    Helper function to return frequency tensors
    :param res: n_dims int tuple of number of frequency modes
    :return:
    """

    n_dims = len(res)
    freqs = []
    for dim in range(n_dims - 1):
        r_ = res[dim]
        freq = np.fft.fftfreq(r_, d=1/r_)
        freqs.append(torch.tensor(freq, dtype=dtype))
    r_ = res[-1]
    if exact:
        freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
    else:
        freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
    omega = torch.meshgrid(freqs)
    omega = list(omega)
    omega = torch.stack(omega, dim=-1)

    return omega

def img(x, deg=1):  # imaginary part of a tensor (assume last dim: real/imag)
    """
    multiply tensor x by i ** deg
    """
    deg %= 4
    if deg == 0:
        res = x
    elif deg == 1:
        res = x[..., [1, 0]]
        res[..., 0] = -res[..., 0]
    elif deg == 2:
        res = -x
    elif deg == 3:
        res = x[..., [1, 0]]
        res[..., 1] = -res[..., 1]
    return res

def spec_gaussian_filter(res, sig):
    omega = fftfreqs(res, dtype=torch.float64)  # [dim0, dim1, dim2, d]
    dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
    filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
    filter_.requires_grad = False

    return filter_

def grid_interp(grid, pts, batched=True):
    """
    :param grid: tensor of shape (batch, *size, in_features)
    :param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
    :return values at query points
    """
    if not batched:
        grid = grid.unsqueeze(0)
        pts = pts.unsqueeze(0)
    dim = pts.shape[-1]
    bs = grid.shape[0]
    size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype)
    cubesize = 1.0 / size

    ind0 = torch.floor(pts / cubesize).long()  # (batch, num_points, dim)
    ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long()  # periodic wrap-around
    ind01 = torch.stack((ind0, ind1), dim=0)  # (2, batch, num_points, dim)
    tmp = torch.tensor([0, 1], dtype=torch.long)
    com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
    dim_ = torch.arange(dim).repeat(com_.shape[0], 1)  # (2**dim, dim)
    ind_ = ind01[com_, ..., dim_]  # (2**dim, dim, batch, num_points)
    ind_n = ind_.permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1)  # (batch, num_points, 2**dim)
    # latent code on neighbor nodes
    if dim == 2:
        lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]]  # (batch, num_points, 2**dim, in_features)
    else:
        lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]]  # (batch, num_points, 2**dim, in_features)

    # weights of neighboring nodes
    xyz0 = ind0.type(cubesize.dtype) * cubesize        # (batch, num_points, dim)
    xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize  # (batch, num_points, dim)
    xyz01 = torch.stack((xyz0, xyz1), dim=0)  # (2, batch, num_points, dim)
    pos = xyz01[com_, ..., dim_].permute(2, 3, 0, 1)    # (batch, num_points, 2**dim, dim)
    pos_ = xyz01[1-com_, ..., dim_].permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    pos_ = pos_.type(pts.dtype)
    dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize  # (batch, num_points, 2**dim, dim)
    weights = torch.prod(dxyz_, dim=-1, keepdim=False)  # (batch, num_points, 2**dim)
    query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2)  # (batch, num_points, in_features)
    if not batched:
        query_values = query_values.squeeze(0)

    return query_values

def scatter_to_grid(inds, vals, size):
    """
    Scatter update values into an empty tensor of the given size.
    :param inds: (#values, dims)
    :param vals: (#values)
    :param size: tuple for size. len(size)=dims
    """
    dims = inds.shape[1]
    assert(inds.shape[0] == vals.shape[0])
    assert(len(size) == dims)
    dev = vals.device
    result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype)  # flatten
    # flatten inds
    fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
    fac = torch.tensor(fac, device=dev).type(inds.dtype)
    inds_fold = torch.sum(inds*fac, dim=-1)  # [#values,]
    result.scatter_add_(0, inds_fold, vals)
    result = result.view(*size)
    return result

def point_rasterize(pts, vals, size):
    """
    :param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
    :param vals: point values, tensor of shape (batch, num_points, features)
    :param size: len(size)=dim tuple for grid size
    :return rasterized values (batch, features, res0, res1, res2)
    """
    dim = pts.shape[-1]
    assert(pts.shape[:2] == vals.shape[:2])
    assert(pts.shape[2] == dim)
    size_list = list(size)
    size = torch.tensor(size).to(pts.device).float()
    cubesize = 1.0 / size
    bs = pts.shape[0]
    nf = vals.shape[-1]
    npts = pts.shape[1]
    dev = pts.device

    ind0 = torch.floor(pts / cubesize).long()  # (batch, num_points, dim)
    ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long()  # periodic wrap-around
    ind01 = torch.stack((ind0, ind1), dim=0)  # (2, batch, num_points, dim)
    tmp = torch.tensor([0, 1], dtype=torch.long)
    com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
    dim_ = torch.arange(dim).repeat(com_.shape[0], 1)  # (2**dim, dim)
    ind_ = ind01[com_, ..., dim_]  # (2**dim, dim, batch, num_points)
    ind_n = ind_.permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1)  # (batch, num_points, 2**dim)

    # weights of neighboring nodes
    xyz0 = ind0.type(cubesize.dtype) * cubesize        # (batch, num_points, dim)
    xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize  # (batch, num_points, dim)
    xyz01 = torch.stack((xyz0, xyz1), dim=0)  # (2, batch, num_points, dim)
    pos = xyz01[com_, ..., dim_].permute(2, 3, 0, 1)    # (batch, num_points, 2**dim, dim)
    pos_ = xyz01[1-com_, ..., dim_].permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    pos_ = pos_.type(pts.dtype)
    dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize  # (batch, num_points, 2**dim, dim)
    weights = torch.prod(dxyz_, dim=-1, keepdim=False)  # (batch, num_points, 2**dim)

    ind_b = ind_b.unsqueeze(-1).unsqueeze(-1)  # (batch, num_points, 2**dim, 1, 1)
    ind_n = ind_n.unsqueeze(-2)                # (batch, num_points, 2**dim, 1, dim)
    ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1)  # (1, 1, 1, nf, 1)

    ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
    ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
    ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
    inds = torch.cat([ind_b, ind_f, ind_n], dim=-1)  # (batch, num_points, 2**dim, nf, 1+1+dim)

    # weighted values
    vals = weights.unsqueeze(-1) * vals.unsqueeze(-2)  # (batch, num_points, 2**dim, nf)

    inds = inds.view(-1, dim+2).permute(1, 0).long()  # (1+dim+1, bs*npts*2**dim*nf)
    vals = vals.reshape(-1)  # (bs*npts*2**dim*nf)
    tensor_size = [bs, nf] + size_list
    raster = scatter_to_grid(inds.permute(1, 0), vals, tensor_size)

    return raster  # [batch, nf, res, res, res]
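
Example (not part of the upload): point_rasterize splats per-point features onto the grid with trilinear weights, and grid_interp reads values back out; note grid_interp expects the feature axis last.

import torch
from models.SAP.utils import point_rasterize, grid_interp

pts = torch.rand(1, 100, 3)    # (batch, num_points, dim) in (0, 1)
vals = torch.randn(1, 100, 4)  # 4 features per point
grid = point_rasterize(pts, vals, (32, 32, 32))       # [1, 4, 32, 32, 32]
back = grid_interp(grid.permute(0, 2, 3, 4, 1), pts)  # [1, 100, 4]
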
##################################################
# Below are general utility functions

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.n = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.n = n
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    @property
    def valcavg(self):
        return self.val.sum().item() / (self.n != 0).sum().item()

    @property
    def avgcavg(self):
        return self.avg.sum().item() / (self.count != 0).sum().item()

def load_model_manual(state_dict, model):
    new_state_dict = OrderedDict()
    is_model_parallel = isinstance(model, torch.nn.DataParallel)
    for k, v in state_dict.items():
        if k.startswith('module.') != is_model_parallel:
            if k.startswith('module.'):
                # remove the 'module.' prefix
                k = k[7:]
            else:
                # add the 'module.' prefix
                k = 'module.' + k

        new_state_dict[k] = v

    model.load_state_dict(new_state_dict)

def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
    '''
    Run marching cubes from PSR grid
    '''
    batch_size = psr_grid.shape[0]
    s = psr_grid.shape[-1]  # size of psr_grid
    psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()

    if batch_size > 1:
        verts, faces, normals = [], [], []
        for i in range(batch_size):
            verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
            verts.append(verts_cur)
            faces.append(faces_cur)
            normals.append(normals_cur)
        verts = np.stack(verts, axis=0)
        faces = np.stack(faces, axis=0)
        normals = np.stack(normals, axis=0)
    else:
        try:
            verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
        except:
            verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
    if real_scale:
        verts = verts / (s - 1)  # scale to range [0, 1]
    else:
        verts = verts / s  # scale to range [0, 1)

    if pytorchify:
        device = psr_grid.device
        verts = torch.Tensor(np.ascontiguousarray(verts)).to(device)
        faces = torch.Tensor(np.ascontiguousarray(faces)).to(device)
        normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)

    return verts, faces, normals
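
Example (not part of the upload): extracting a mesh from a PSR grid with mc_from_psr; a random field has zero crossings throughout, so it works as a toy input.

import torch
from models.SAP.utils import mc_from_psr

grid = torch.randn(1, 64, 64, 64)  # toy PSR field
v, f, n = mc_from_psr(grid, pytorchify=True)  # torch tensors, verts scaled to [0, 1)
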
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
    verts = verts.squeeze()
    faces = faces.squeeze()
    pix_to_face, w, mask = mesh_rasterization(verts, faces, pose, img_size)
    if mask_gt is not None:
        #! only evaluate within the intersection
        mask = mask & mask_gt
    # find the 3D points intersected on the mesh
    if True:
        w_masked = w[mask]
        f_p = faces[pix_to_face[mask]].long()  # corresponding faces for each pixel
        # corresponding vertices for p_closest
        v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]

        # calculate the intersection point of each pixel and the mesh
        p_inters = w_masked[..., 0, None] * v_a + \
                   w_masked[..., 1, None] * v_b + \
                   w_masked[..., 2, None] * v_c
    else:
        # Dead branch kept from the original SAP code: `uv` and `zbuf` are not
        # defined in this upload, so this path cannot run as written.
        # backproject ndc to world coordinates using the z-buffer
        W, H = img_size[1], img_size[0]
        xy = uv.to(mask.device)[mask]
        x_ndc = 1 - (2*xy[:, 0]) / (W - 1)
        y_ndc = 1 - (2*xy[:, 1]) / (H - 1)
        z = zbuf.squeeze().reshape(H * W)[mask]
        xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)

        p_inters = pose.unproject_points(xy_depth, world_coordinates=True)

    # if there are outlier points, remove them
    if (p_inters.max() > 1) | (p_inters.min() < -1):
        mask_bound = (p_inters >= -1) & (p_inters <= 1)
        mask_bound = (mask_bound.sum(dim=-1) == 3)
        mask[mask == True] = mask_bound
        p_inters = p_inters[mask_bound]
        print('!!!!!found outlier!')

    return p_inters, mask, f_p, w_masked

def mesh_rasterization(verts, faces, pose, img_size):
    '''
    Use PyTorch3D to rasterize the mesh given a camera
    '''
    transformed_v = pose.transform_points(verts.detach())  # world -> ndc coordinate system
    if isinstance(pose, PerspectiveCameras):
        transformed_v[..., 2] = 1/transformed_v[..., 2]
    # find p_closest on the mesh for each pixel via rasterization
    transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
    pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
        transformed_mesh,
        image_size=img_size,
        blur_radius=0,
        faces_per_pixel=1,
        perspective_correct=False
    )
    pix_to_face = pix_to_face.reshape(1, -1)  # B x reso x reso -> B x (reso x reso)
    mask = pix_to_face.clone() != -1
    mask = mask.squeeze()
    pix_to_face = pix_to_face.squeeze()
    w = bary_coords.reshape(-1, 3)

    return pix_to_face, w, mask

def verts_on_largest_mesh(verts, faces):
    '''
    verts: Numpy array or torch.Tensor (N, 3)
    faces: Numpy array (N, 3)
    '''
    if torch.is_tensor(faces):
        verts = verts.squeeze().detach().cpu().numpy()
        faces = faces.squeeze().int().detach().cpu().numpy()

    A = adjacency_matrix(faces)
    num, conn_idx, conn_size = connected_components(A)
    if num == 0:
        v_large, f_large = verts, faces
    else:
        max_idx = conn_size.argmax()  # find the index of the largest component
        v_large = verts[conn_idx == max_idx]  # keep points on the largest component

        if True:
            mesh_largest = trimesh.Trimesh(verts, faces)
            connected_comp = mesh_largest.split(only_watertight=False)
            mesh_largest = connected_comp[max_idx]
            v_large, f_large = mesh_largest.vertices, mesh_largest.faces
            v_large = v_large.astype(np.float32)
    return v_large, f_large

def update_recursive(dict1, dict2):
    ''' Update two config dictionaries recursively.

    Args:
        dict1 (dict): first dictionary to be updated
        dict2 (dict): second dictionary whose entries should be used

    '''
    for k, v in dict2.items():
        if k not in dict1:
            dict1[k] = dict()
        if isinstance(v, dict):
            update_recursive(dict1[k], v)
        else:
            dict1[k] = v

def scale2onet(p, scale=1.2):
    '''
    Scale the point cloud from SAP to ONet range
    '''
    return (p - 0.5) * scale

def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
    if model is not None:
        if schedule is not None:
            optimizer = torch.optim.Adam([
                {"params": model.parameters(),
                 "lr": schedule[0].get_learning_rate(epoch)},
                {"params": inputs,
                 "lr": schedule[1].get_learning_rate(epoch)}])
        elif 'lr' in cfg['train']:
            optimizer = torch.optim.Adam([
                {"params": model.parameters(),
                 "lr": float(cfg['train']['lr'])},
                {"params": inputs,
                 "lr": float(cfg['train']['lr_pcl'])}])
        else:
            raise Exception('no known learning rate')
    else:
        if schedule is not None:
            optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
        else:
            optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))

    return optimizer


def is_url(url):
    scheme = urllib.parse.urlparse(url).scheme
    return scheme in ('http', 'https')

def load_url(url):
    '''Load a module dictionary from a url.

    Args:
        url (str): url to the saved model
    '''
    print(url)
    print('=> Loading checkpoint from url...')
    state_dict = model_zoo.load_url(url, progress=True)

    return state_dict


class GaussianSmoothing(nn.Module):
    """
    Apply Gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed
    separately for each channel in the input using a depthwise convolution.
    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
        kernel_size (int, sequence): Size of the Gaussian kernel.
        sigma (float, sequence): Standard deviation of the Gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    """
    def __init__(self, channels, kernel_size, sigma, dim=3):
        super(GaussianSmoothing, self).__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The Gaussian kernel is the product of the
        # Gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                      torch.exp(-((mgrid - mean) / std) ** 2 / 2)

        # Make sure the sum of values in the Gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

    def forward(self, input):
        """
        Apply the Gaussian filter to the input.
        Arguments:
            input (torch.Tensor): Input to apply the Gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        return self.conv(input, weight=self.weight, groups=self.groups)
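
Example (not part of the upload): depthwise Gaussian smoothing of a one-channel 3D field; no padding is applied, so the output shrinks by kernel_size - 1 per axis.

import torch
from models.SAP.utils import GaussianSmoothing

smoother = GaussianSmoothing(channels=1, kernel_size=5, sigma=1.0, dim=3)
field = torch.randn(1, 1, 32, 32, 32)
smoothed = smoother(field)  # [1, 1, 28, 28, 28]
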
# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
def get_learning_rate_schedules(schedule_specs):

    schedules = []

    for key in schedule_specs.keys():
        schedules.append(StepLearningRateSchedule(
            schedule_specs[key]['initial'],
            schedule_specs[key]["interval"],
            schedule_specs[key]["factor"],
            schedule_specs[key]["final"]))
    return schedules

class LearningRateSchedule:
    def get_learning_rate(self, epoch):
        pass

class StepLearningRateSchedule(LearningRateSchedule):
    def __init__(self, initial, interval, factor, final=1e-6):
        self.initial = float(initial)
        self.interval = interval
        self.factor = factor
        self.final = float(final)

    def get_learning_rate(self, epoch):
        lr = np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
        if lr > self.final:
            return lr
        else:
            return self.final

def adjust_learning_rate(lr_schedules, optimizer, epoch):
    for i, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
models/__init__.py
ADDED
File without changes
models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (129 Bytes).
models/__pycache__/model.cpython-39.pyc
ADDED
Binary file (5.56 kB).
models/model.py
ADDED
@@ -0,0 +1,181 @@
import random
from statistics import mean
from typing import List, Tuple

import torch as th
import pytorch_lightning as pl
from jaxtyping import Float, Int
import numpy as np
from torch_geometric.nn.conv import GATv2Conv

from models.SAP.dpsr import DPSR
from models.SAP.model import PSR2Mesh

# Constants

th.manual_seed(0)
np.random.seed(0)

BATCH_SIZE = 1  # BS

IN_DIM = 1
OUT_DIM = 1
LATENT_DIM = 32

DROPOUT_PROB = 0.1
GRID_SIZE = 128

def generate_grid_edge_list(gs: int = 128):
    grid_edge_list = []

    for k in range(gs):
        for j in range(gs):
            for i in range(gs):
                current_idx = i + gs*j + k*gs*gs
                if (i - 1) >= 0:
                    grid_edge_list.append([current_idx, i-1 + gs*j + k*gs*gs])
                if (i + 1) < gs:
                    grid_edge_list.append([current_idx, i+1 + gs*j + k*gs*gs])
                if (j - 1) >= 0:
                    grid_edge_list.append([current_idx, i + gs*(j-1) + k*gs*gs])
                if (j + 1) < gs:
                    grid_edge_list.append([current_idx, i + gs*(j+1) + k*gs*gs])
                if (k - 1) >= 0:
                    grid_edge_list.append([current_idx, i + gs*j + (k-1)*gs*gs])
                if (k + 1) < gs:
                    grid_edge_list.append([current_idx, i + gs*j + (k+1)*gs*gs])
    return grid_edge_list
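
Aside (not part of the upload): generate_grid_edge_list enumerates the 6-neighbourhood of a gs^3 voxel grid, 6*gs^2*(gs-1) directed edges (about 12.5M at gs=128), via a Python triple loop. A vectorized equivalent, shown only as an illustration:

import torch as th

def grid_edges_vectorized(gs: int = 128) -> th.Tensor:
    idx = th.arange(gs ** 3).reshape(gs, gs, gs)    # idx[k, j, i] = i + gs*j + gs*gs*k
    pairs = []
    for dim in range(3):
        a = idx.narrow(dim, 0, gs - 1).reshape(-1)  # cells with a +1 neighbour along `dim`
        b = idx.narrow(dim, 1, gs - 1).reshape(-1)  # that neighbour
        pairs.append(th.stack([a, b]))
        pairs.append(th.stack([b, a]))              # both directions, matching the loop version
    return th.cat(pairs, dim=1)                     # shape [2, 6*gs*gs*(gs-1)]
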
GRID_EDGE_LIST = generate_grid_edge_list(GRID_SIZE)
GRID_EDGE_LIST = th.tensor(GRID_EDGE_LIST, dtype=th.int)
GRID_EDGE_LIST = GRID_EDGE_LIST.T
# GRID_EDGE_LIST = GRID_EDGE_LIST.to(th.device("cuda"))
GRID_EDGE_LIST.requires_grad = False  # Do not forget to delete this if training


class FormOptimizer(th.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)

        self.actv = th.nn.Sigmoid()
        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)

    def forward(self,
                field: Float[th.Tensor, "GS GS GS"]) -> Float[th.Tensor, "GS GS GS"]:
        """
        Args:
            field (Tensor [GS, GS, GS]): vertices and normals tensor.
        """
        vertex_features = field.clone()
        vertex_features = vertex_features.reshape(GRID_SIZE*GRID_SIZE*GRID_SIZE, IN_DIM)

        vertex_features = self.gconv1(x=vertex_features, edge_index=GRID_EDGE_LIST)
        vertex_features = self.gconv2(x=vertex_features, edge_index=GRID_EDGE_LIST)
        field_delta = self.head(self.actv(vertex_features))

        field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
        field_delta += field  # field_delta carries the gradient
        field_delta = th.clamp(field_delta, min=-0.5, max=0.5)

        return field_delta


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.form_optimizer = FormOptimizer()

        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
        self.field2mesh = PSR2Mesh().apply

        self.metric = th.nn.MSELoss()

        self.val_losses = []
        self.train_losses = []

    def log_h5(self, points, normals):
        # NOTE: self.log_points_file, self.log_normals_file and self.h5_frame
        # are expected to be attached externally; they are not set in __init__.
        dset = self.log_points_file.create_dataset(
            name=str(self.h5_frame),
            shape=points.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = points
        dset = self.log_normals_file.create_dataset(
            name=str(self.h5_frame),
            shape=normals.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = normals
        self.h5_frame += 1

    def forward(self,
                v: Float[th.Tensor, "BS N 3"],
                n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"],     # v - vertices
                                                        Int[th.Tensor, "2 E"],          # f - faces
                                                        Float[th.Tensor, "BS N 3"],     # n - vertex normals
                                                        Float[th.Tensor, "BS GR GR GR"]]:  # field
        field = self.dpsr(v, n)
        field = self.form_optimizer(field)
        v, f, n = self.field2mesh(field)
        return v, f, n, field

    def training_step(self, batch, batch_idx) -> Float[th.Tensor, "1"]:
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch

        # was th.device("cuda"); use the batch's device so CPU runs also work
        mask = th.rand((vertices.shape[1], ), device=vertices.device) < (random.random() / 2.0 + 0.5)
        vertices = vertices[:, mask]
        vertices_normals = vertices_normals[:, mask]

        vr, fr, nr, field_r = self(vertices, vertices_normals)  # was `model(...)`: fixed to call this module

        loss = self.metric(field_r, field_gt)
        train_per_step_loss = loss.item()
        self.train_losses.append(train_per_step_loss)

        return loss

    def on_train_epoch_end(self):
        mean_train_per_epoch_loss = mean(self.train_losses)
        self.log("mean_train_per_epoch_loss", mean_train_per_epoch_loss, on_step=False, on_epoch=True)
        self.train_losses = []

    def validation_step(self, batch, batch_idx):
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch

        vr, fr, nr, field_r = self(vertices, vertices_normals)  # was `model(...)`: fixed to call this module

        loss = self.metric(field_r, field_gt)
        val_per_step_loss = loss.item()
        self.val_losses.append(val_per_step_loss)
        return loss

    def on_validation_epoch_end(self):
        mean_val_per_epoch_loss = mean(self.val_losses)
        self.log("mean_val_per_epoch_loss", mean_val_per_epoch_loss, on_step=False, on_epoch=True)
        self.val_losses = []

    def configure_optimizers(self):
        # NOTE: LR is not defined anywhere in this upload; 1e-3 is a placeholder.
        LR = 1e-3
        optimizer = th.optim.Adam(self.parameters(), lr=LR)
        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "mean_val_per_epoch_loss",
                "interval": "epoch",
                "frequency": 1,
                # If set to `True`, will enforce that the value specified by 'monitor'
                # is available when the scheduler is updated, thus stopping
                # training if not found. If set to `False`, it will only produce a warning
                "strict": True,
                # If using the `LearningRateMonitor` callback to monitor the
                # learning rate progress, this keyword can be used to specify
                # a custom logged name
                "name": None,
            }
        }
requirements.txt
ADDED
@@ -0,0 +1,298 @@
aiofiles==23.1.0
aiohttp==3.8.4
aiosignal==1.3.1
altair==5.0.0
ansicon==1.89.0
anyio==3.6.2
arrow==1.2.3
asttokens==2.2.1
async-timeout==4.0.2
attrs==23.1.0
backcall==0.2.0
beautifulsoup4==4.12.2
blessed==1.20.0
blinker==1.6.2
certifi==2023.5.7
charset-normalizer==3.1.0
click==8.1.3
colorama==0.4.6
comm==0.1.3
ConfigArgParse==1.5.3
contourpy==1.0.7
croniter==1.3.14
cycler==0.11.0
dash==2.9.3
dash-core-components==2.0.0
dash-html-components==2.0.0
dash-table==5.0.0
dateutils==0.6.12
debugpy==1.6.7
decorator==5.1.1
deepdiff==6.3.0
executing==1.2.0
fastapi==0.88.0
fastjsonschema==2.16.3
ffmpy==0.3.0
filelock==3.12.0
Flask==2.3.2
fonttools==4.39.4
frozenlist==1.3.3
fsspec==2023.5.0
fvcore==0.1.5.post20221221
gradio==3.30.0
gradio_client==0.2.5
h11==0.14.0
httpcore==0.17.0
httpx==0.24.0
huggingface-hub==0.14.1
idna==3.4
imageio==2.28.1
importlib-metadata==6.6.0
importlib-resources==5.12.0
inquirer==3.1.3
iopath==0.1.10
ipykernel==6.23.1
ipython==8.13.2
ipywidgets==8.0.6
itsdangerous==2.1.2
jaxtyping==0.2.19
jedi==0.18.2
Jinja2==3.1.2
jinxed==1.2.0
joblib==1.2.0
jsonschema==4.17.3
jupyter_client==8.2.0
jupyter_core==5.3.0
jupyterlab-widgets==3.0.7
kiwisolver==1.4.4
lazy_loader==0.2
libigl==2.4.1
lightning==2.0.2
lightning-cloud==0.5.36
lightning-utilities==0.8.0
linkify-it-py==2.0.2
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdit-py-plugins==0.3.3
mdurl==0.1.2
multidict==6.0.4
nbformat==5.7.0
nest-asyncio==1.5.6
networkx==3.1
numpy==1.24.3
open3d==0.17.0
ordered-set==4.1.0
orjson==3.8.12
packaging==23.1
pandas==2.0.1
parso==0.8.3
pickleshare==0.7.5
Pillow==9.5.0
platformdirs==3.5.1
plotly==5.14.1
plyfile==0.9
portalocker==2.7.0
prompt-toolkit==3.0.38
psutil==5.9.5
pure-eval==0.2.2
pydantic==1.10.7
pydub==0.25.1
Pygments==2.15.1
PyJWT==2.7.0
pyparsing==3.0.9
pyrsistent==0.19.3
PySimpleGUI==4.60.4
python-dateutil==2.8.2
python-editor==1.0.4
python-multipart==0.0.6
pytorch-lightning==2.0.2
pytz==2023.3
PyWavelets==1.4.1
pywin32==306
PyYAML==6.0
pyzmq==25.0.2
readchar==4.0.5
requests==2.30.0
rich==13.3.5
scikit-image==0.20.0
scikit-learn==1.2.2
scipy==1.9.1
semantic-version==2.10.0
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
stack-data==0.6.2
starlette==0.22.0
starsessions==1.3.0
tabulate==0.9.0
tenacity==8.2.2
termcolor==2.3.0
threadpoolctl==3.1.0
tifffile==2023.4.12
toolz==0.12.0
torch==1.13.1+cu116
torch-cluster==1.6.1+pt113cu116
torch-geometric==2.3.1
torch-scatter==2.1.1+pt113cu116
torch-sparse==0.6.17+pt113cu116
torch-spline-conv==1.2.2+pt113cu116
torchaudio==0.13.1
torchmetrics==0.11.4
torchvision==0.14.1+cu116
tornado==6.3.2
tqdm==4.65.0
traitlets==5.9.0
trimesh==3.21.6
typeguard==4.0.0
typing_extensions==4.5.0
tzdata==2023.3
uc-micro-py==1.0.2
urllib3==2.0.2
uvicorn==0.22.0
wcwidth==0.2.6
websocket-client==1.5.1
websockets==11.0.3
Werkzeug==2.3.4
widgetsnbextension==4.0.7
yacs==0.1.8
yarl==1.9.2
zipp==3.15.0