MVV commited on
Commit
583d0a8
1 Parent(s): be22bb4

Upload 19 files

Browse files
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+
4
+ import pytorch_lightning as pl
5
+ import torch as th
6
+ import open3d as o3d
7
+ import numpy as np
8
+ import trimesh as tm
9
+
10
+ from models.model import Model
11
+
12
+ model = Model()
13
+ ckpg = th.load("./checkpoints/epoch=99-step=6000.ckpt")
14
+ model.load_state_dict(ckpg["state_dict"])
15
+
16
+
17
+ def process_mesh(mesh_file_name):
18
+
19
+ mesh = tm.load_mesh(mesh_file_name)
20
+
21
+ v = th.tensor(mesh.vertices, dtype=th.float)
22
+ n = th.tensor(mesh.vertex_normals, dtype=th.float)
23
+
24
+ with th.no_grad():
25
+ v, f, n, _ = model(v.unsqueeze(0), n.unsqueeze(0))
26
+
27
+ mesh = tm.Trimesh(vertices=v.squeeze(0),
28
+ faces=f.squeeze(0),
29
+ vertex_normals=n.squeeze(0))
30
+ obj_path = "./sample.obj"
31
+ mesh.export(obj_path)
32
+
33
+ return obj_path
34
+
35
+
36
+ demo = gr.Interface(
37
+ fn=process_mesh,
38
+ inputs=gr.Model3D(),
39
+ outputs=gr.Model3D(
40
+ clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
41
+ examples=[
42
+ [os.path.join(os.path.dirname(__file__), "files\\bunny_n1_hi_50.obj")],
43
+ [os.path.join(os.path.dirname(__file__), "files\\child_n2_80.obj")],
44
+ [os.path.join(os.path.dirname(__file__), "files\\eight_n3_70.obj")],
45
+ ],
46
+ )
47
+
48
+ if __name__ == "__main__":
49
+ demo.launch()
checkpoints/epoch=99-step=6000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3a025dfa88edbf34bf7cf2b69c554cb01ab7d61b4d8cc699a2a6753e14dbdea
3
+ size 4308343
files/bunny_n1_hi_50.obj ADDED
The diff for this file is too large to render. See raw diff
 
files/child_n2_80.obj ADDED
The diff for this file is too large to render. See raw diff
 
files/eight_n3_70.obj ADDED
The diff for this file is too large to render. See raw diff
 
models/SAP/__init__.py ADDED
File without changes
models/SAP/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (133 Bytes). View file
 
models/SAP/__pycache__/dpsr.cpython-39.pyc ADDED
Binary file (2.47 kB). View file
 
models/SAP/__pycache__/model.cpython-39.pyc ADDED
Binary file (4.08 kB). View file
 
models/SAP/__pycache__/utils.cpython-39.pyc ADDED
Binary file (15 kB). View file
 
models/SAP/dpsr.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .utils import spec_gaussian_filter, fftfreqs, img, grid_interp, point_rasterize
4
+ import numpy as np
5
+ import torch.fft
6
+
7
+ class DPSR(nn.Module):
8
+ def __init__(self, res, sig=10, scale=True, shift=True):
9
+ """
10
+ :param res: tuple of output field resolution. eg., (128,128)
11
+ :param sig: degree of gaussian smoothing
12
+ """
13
+ super(DPSR, self).__init__()
14
+ self.res = res
15
+ self.sig = sig
16
+ self.dim = len(res)
17
+ self.denom = np.prod(res)
18
+ G = spec_gaussian_filter(res=res, sig=sig).float()
19
+ # self.G.requires_grad = False # True, if we also make sig a learnable parameter
20
+ self.omega = fftfreqs(res, dtype=torch.float32)
21
+ self.scale = scale
22
+ self.shift = shift
23
+ self.register_buffer("G", G)
24
+
25
+ def forward(self, V, N):
26
+ """
27
+ :param V: (batch, nv, 2 or 3) tensor for point cloud coordinates
28
+ :param N: (batch, nv, 2 or 3) tensor for point normals
29
+ :return phi: (batch, res, res, ...) tensor of output indicator function field
30
+ """
31
+ assert(V.shape == N.shape) # [b, nv, ndims]
32
+ ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
33
+
34
+ ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
35
+ ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
36
+ N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
37
+
38
+ omega = fftfreqs(self.res, dtype=torch.float32).unsqueeze(-1) # [dim0, dim1, dim2/2+1, n_dim, 1]
39
+ omega *= 2 * np.pi # normalize frequencies
40
+ omega = omega.to(V.device)
41
+
42
+ DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
43
+
44
+ Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
45
+ Phi = DivN / (Lap+1e-6) # [b, dim0, dim1, dim2/2+1, 2]
46
+ Phi = Phi.permute(*tuple([list(range(1,self.dim+2)) + [0]])) # [dim0, dim1, dim2/2+1, 2, b]
47
+ Phi[tuple([0] * self.dim)] = 0
48
+ Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
49
+
50
+ phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
51
+
52
+ if self.shift or self.scale:
53
+ # ensure values at points are zero
54
+ fv = grid_interp(phi.unsqueeze(-1), V, batched=True).squeeze(-1) # [b, nv]
55
+ if self.shift: # offset points to have mean of 0
56
+ offset = torch.mean(fv, dim=-1) # [b,]
57
+ phi -= offset.view(*tuple([-1] + [1] * self.dim))
58
+
59
+ phi = phi.permute(*tuple([list(range(1,self.dim+1)) + [0]]))
60
+ fv0 = phi[tuple([0] * self.dim)] # [b,]
61
+ phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
62
+
63
+ if self.scale:
64
+ phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
65
+ return phi
models/SAP/model.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import time
4
+ from .utils import point_rasterize, grid_interp, mc_from_psr, \
5
+ calc_inters_points
6
+ from .dpsr import DPSR
7
+ import torch.nn as nn
8
+
9
+ class PSR2Mesh(torch.autograd.Function):
10
+ @staticmethod
11
+ def forward(ctx, psr_grid):
12
+ """
13
+ In the forward pass we receive a Tensor containing the input and return
14
+ a Tensor containing the output. ctx is a context object that can be used
15
+ to stash information for backward computation. You can cache arbitrary
16
+ objects for use in the backward pass using the ctx.save_for_backward method.
17
+ """
18
+ verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
19
+ verts = verts.unsqueeze(0)
20
+ faces = faces.unsqueeze(0)
21
+ normals = normals.unsqueeze(0)
22
+
23
+ res = torch.tensor(psr_grid.detach().shape[2])
24
+ ctx.save_for_backward(verts, normals, res)
25
+
26
+ return verts, faces, normals
27
+
28
+ @staticmethod
29
+ def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
30
+ """
31
+ In the backward pass we receive a Tensor containing the gradient of the loss
32
+ with respect to the output, and we need to compute the gradient of the loss
33
+ with respect to the input.
34
+ """
35
+ vert_pts, normals, res = ctx.saved_tensors
36
+ res = (res.item(), res.item(), res.item())
37
+ # matrix multiplication between dL/dV and dV/dPSR
38
+ # dV/dPSR = - normals
39
+ grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
40
+ grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res
41
+
42
+ return grad_grid
43
+
44
+ class PSR2SurfacePoints(torch.autograd.Function):
45
+ @staticmethod
46
+ def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
47
+ verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
48
+ verts = verts * 2. - 1. # within the range of [-1, 1]
49
+
50
+
51
+ p_all, n_all, mask_all = [], [], []
52
+
53
+ for i in range(len(poses)):
54
+ pose = poses[i]
55
+ if mask_sample is not None:
56
+ p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i])
57
+ else:
58
+ p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size)
59
+
60
+ n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze()
61
+ p_all.append(p_inters)
62
+ n_all.append(n_inters)
63
+ mask_all.append(mask)
64
+ p_inters_all = torch.cat(p_all, dim=0)
65
+ n_inters_all = torch.cat(n_all, dim=0)
66
+ mask_visible = torch.stack(mask_all, dim=0)
67
+
68
+
69
+ res = torch.tensor(psr_grid.detach().shape[2])
70
+ ctx.save_for_backward(p_inters_all, n_inters_all, res)
71
+
72
+ return p_inters_all, mask_visible
73
+
74
+ @staticmethod
75
+ def backward(ctx, dL_dp, dL_dmask):
76
+ pts, pts_n, res = ctx.saved_tensors
77
+ res = (res.item(), res.item(), res.item())
78
+
79
+ # grad from the p_inters via MLP renderer
80
+ grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
81
+ grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
82
+
83
+ return grad_grid_pts, None, None, None, None, None
84
+
85
+
86
+ # Resnet Blocks from https://github.com/autonomousvision/shape_as_points/blob/12757682f1075d83738b52f96747463b77343caf/src/network/utils.py
87
+ class ResnetBlockFC(nn.Module):
88
+ ''' Fully connected ResNet Block class.
89
+ Args:
90
+ size_in (int): input dimension
91
+ size_out (int): output dimension
92
+ size_h (int): hidden dimension
93
+ '''
94
+
95
+ def __init__(self, size_in, size_out=None, size_h=None, siren=False):
96
+ super().__init__()
97
+ # Attributes
98
+ if size_out is None:
99
+ size_out = size_in
100
+
101
+ if size_h is None:
102
+ size_h = min(size_in, size_out)
103
+
104
+ self.size_in = size_in
105
+ self.size_h = size_h
106
+ self.size_out = size_out
107
+ # Submodules
108
+ self.fc_0 = nn.Linear(size_in, size_h)
109
+ self.fc_1 = nn.Linear(size_h, size_out)
110
+ self.actvn = nn.ReLU()
111
+
112
+ if size_in == size_out:
113
+ self.shortcut = None
114
+ else:
115
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
116
+ # Initialization
117
+ nn.init.zeros_(self.fc_1.weight)
118
+
119
+ def forward(self, x):
120
+ net = self.fc_0(self.actvn(x))
121
+ dx = self.fc_1(self.actvn(net))
122
+
123
+ if self.shortcut is not None:
124
+ x_s = self.shortcut(x)
125
+ else:
126
+ x_s = x
127
+
128
+ return x_s + dx
129
+
models/SAP/utils.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import io, os, logging, urllib
3
+ import yaml
4
+ import trimesh
5
+ import imageio
6
+ import numbers
7
+ import math
8
+ import numpy as np
9
+ from collections import OrderedDict
10
+ from plyfile import PlyData
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from torch.utils import model_zoo
14
+ from skimage import measure, img_as_float32
15
+ from igl import adjacency_matrix, connected_components
16
+
17
+ ##################################################
18
+ # Below are functions for DPSR
19
+
20
+ def fftfreqs(res, dtype=torch.float32, exact=True):
21
+ """
22
+ Helper function to return frequency tensors
23
+ :param res: n_dims int tuple of number of frequency modes
24
+ :return:
25
+ """
26
+
27
+ n_dims = len(res)
28
+ freqs = []
29
+ for dim in range(n_dims - 1):
30
+ r_ = res[dim]
31
+ freq = np.fft.fftfreq(r_, d=1/r_)
32
+ freqs.append(torch.tensor(freq, dtype=dtype))
33
+ r_ = res[-1]
34
+ if exact:
35
+ freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
36
+ else:
37
+ freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
38
+ omega = torch.meshgrid(freqs)
39
+ omega = list(omega)
40
+ omega = torch.stack(omega, dim=-1)
41
+
42
+ return omega
43
+
44
+ def img(x, deg=1): # imaginary of tensor (assume last dim: real/imag)
45
+ """
46
+ multiply tensor x by i ** deg
47
+ """
48
+ deg %= 4
49
+ if deg == 0:
50
+ res = x
51
+ elif deg == 1:
52
+ res = x[..., [1, 0]]
53
+ res[..., 0] = -res[..., 0]
54
+ elif deg == 2:
55
+ res = -x
56
+ elif deg == 3:
57
+ res = x[..., [1, 0]]
58
+ res[..., 1] = -res[..., 1]
59
+ return res
60
+
61
+ def spec_gaussian_filter(res, sig):
62
+ omega = fftfreqs(res, dtype=torch.float64) # [dim0, dim1, dim2, d]
63
+ dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
64
+ filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
65
+ filter_.requires_grad = False
66
+
67
+ return filter_
68
+
69
+ def grid_interp(grid, pts, batched=True):
70
+ """
71
+ :param grid: tensor of shape (batch, *size, in_features)
72
+ :param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
73
+ :return values at query points
74
+ """
75
+ if not batched:
76
+ grid = grid.unsqueeze(0)
77
+ pts = pts.unsqueeze(0)
78
+ dim = pts.shape[-1]
79
+ bs = grid.shape[0]
80
+ size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype)
81
+ cubesize = 1.0 / size
82
+
83
+ ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
84
+ ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
85
+ ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
86
+ tmp = torch.tensor([0,1],dtype=torch.long)
87
+ com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
88
+ dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
89
+ ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
90
+ ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
91
+ ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
92
+ # latent code on neighbor nodes
93
+ if dim == 2:
94
+ lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]] # (batch, num_points, 2**dim, in_features)
95
+ else:
96
+ lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]] # (batch, num_points, 2**dim, in_features)
97
+
98
+ # weights of neighboring nodes
99
+ xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
100
+ xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
101
+ xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
102
+ pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
103
+ pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
104
+ pos_ = pos_.type(pts.dtype)
105
+ dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
106
+ weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
107
+ query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2) # (batch, num_points, in_features)
108
+ if not batched:
109
+ query_values = query_values.squeeze(0)
110
+
111
+ return query_values
112
+
113
+ def scatter_to_grid(inds, vals, size):
114
+ """
115
+ Scatter update values into empty tensor of size size.
116
+ :param inds: (#values, dims)
117
+ :param vals: (#values)
118
+ :param size: tuple for size. len(size)=dims
119
+ """
120
+ dims = inds.shape[1]
121
+ assert(inds.shape[0] == vals.shape[0])
122
+ assert(len(size) == dims)
123
+ dev = vals.device
124
+ # result = torch.zeros(*size).view(-1).to(dev).type(vals.dtype) # flatten
125
+ # # flatten inds
126
+ result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype) # flatten
127
+ # flatten inds
128
+ fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
129
+ fac = torch.tensor(fac, device=dev).type(inds.dtype)
130
+ inds_fold = torch.sum(inds*fac, dim=-1) # [#values,]
131
+ result.scatter_add_(0, inds_fold, vals)
132
+ result = result.view(*size)
133
+ return result
134
+
135
+ def point_rasterize(pts, vals, size):
136
+ """
137
+ :param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
138
+ :param vals: point values, tensor of shape (batch, num_points, features)
139
+ :param size: len(size)=dim tuple for grid size
140
+ :return rasterized values (batch, features, res0, res1, res2)
141
+ """
142
+ dim = pts.shape[-1]
143
+ assert(pts.shape[:2] == vals.shape[:2])
144
+ assert(pts.shape[2] == dim)
145
+ size_list = list(size)
146
+ size = torch.tensor(size).to(pts.device).float()
147
+ cubesize = 1.0 / size
148
+ bs = pts.shape[0]
149
+ nf = vals.shape[-1]
150
+ npts = pts.shape[1]
151
+ dev = pts.device
152
+
153
+ ind0 = torch.floor(pts / cubesize).long() # (batch, num_points, dim)
154
+ ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long() # periodic wrap-around
155
+ ind01 = torch.stack((ind0, ind1), dim=0) # (2, batch, num_points, dim)
156
+ tmp = torch.tensor([0,1],dtype=torch.long)
157
+ com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)
158
+ dim_ = torch.arange(dim).repeat(com_.shape[0], 1) # (2**dim, dim)
159
+ ind_ = ind01[com_, ..., dim_] # (2**dim, dim, batch, num_points)
160
+ ind_n = ind_.permute(2, 3, 0, 1) # (batch, num_points, 2**dim, dim)
161
+ # ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
162
+ ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1) # (batch, num_points, 2**dim)
163
+
164
+ # weights of neighboring nodes
165
+ xyz0 = ind0.type(cubesize.dtype) * cubesize # (batch, num_points, dim)
166
+ xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize # (batch, num_points, dim)
167
+ xyz01 = torch.stack((xyz0, xyz1), dim=0) # (2, batch, num_points, dim)
168
+ pos = xyz01[com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
169
+ pos_ = xyz01[1-com_, ..., dim_].permute(2,3,0,1) # (batch, num_points, 2**dim, dim)
170
+ pos_ = pos_.type(pts.dtype)
171
+ dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize # (batch, num_points, 2**dim, dim)
172
+ weights = torch.prod(dxyz_, dim=-1, keepdim=False) # (batch, num_points, 2**dim)
173
+
174
+ ind_b = ind_b.unsqueeze(-1).unsqueeze(-1) # (batch, num_points, 2**dim, 1, 1)
175
+ ind_n = ind_n.unsqueeze(-2) # (batch, num_points, 2**dim, 1, dim)
176
+ ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
177
+ # ind_f = torch.arange(nf).view(1, 1, 1, nf, 1) # (1, 1, 1, nf, 1)
178
+
179
+ ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
180
+ ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
181
+ ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
182
+ inds = torch.cat([ind_b, ind_f, ind_n], dim=-1) # (batch, num_points, 2**dim, nf, 1+1+dim)
183
+
184
+ # weighted values
185
+ vals = weights.unsqueeze(-1) * vals.unsqueeze(-2) # (batch, num_points, 2**dim, nf)
186
+
187
+ inds = inds.view(-1, dim+2).permute(1, 0).long() # (1+dim+1, bs*npts*2**dim*nf)
188
+ vals = vals.reshape(-1) # (bs*npts*2**dim*nf)
189
+ tensor_size = [bs, nf] + size_list
190
+ raster = scatter_to_grid(inds.permute(1, 0), vals, [bs, nf] + size_list)
191
+
192
+ return raster # [batch, nf, res, res, res]
193
+
194
+
195
+
196
+ ##################################################
197
+ # Below are the utilization functions in general
198
+
199
+ class AverageMeter(object):
200
+ """Computes and stores the average and current value"""
201
+ def __init__(self):
202
+ self.reset()
203
+
204
+ def reset(self):
205
+ self.val = 0
206
+ self.n = 0
207
+ self.avg = 0
208
+ self.sum = 0
209
+ self.count = 0
210
+
211
+ def update(self, val, n=1):
212
+ self.val = val
213
+ self.n = n
214
+ self.sum += val * n
215
+ self.count += n
216
+ self.avg = self.sum / self.count
217
+
218
+ @property
219
+ def valcavg(self):
220
+ return self.val.sum().item() / (self.n != 0).sum().item()
221
+
222
+ @property
223
+ def avgcavg(self):
224
+ return self.avg.sum().item() / (self.count != 0).sum().item()
225
+
226
+ def load_model_manual(state_dict, model):
227
+ new_state_dict = OrderedDict()
228
+ is_model_parallel = isinstance(model, torch.nn.DataParallel)
229
+ for k, v in state_dict.items():
230
+ if k.startswith('module.') != is_model_parallel:
231
+ if k.startswith('module.'):
232
+ # remove module
233
+ k = k[7:]
234
+ else:
235
+ # add module
236
+ k = 'module.' + k
237
+
238
+ new_state_dict[k]=v
239
+
240
+ model.load_state_dict(new_state_dict)
241
+
242
+ def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
243
+ '''
244
+ Run marching cubes from PSR grid
245
+ '''
246
+ batch_size = psr_grid.shape[0]
247
+ s = psr_grid.shape[-1] # size of psr_grid
248
+ psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
249
+
250
+ if batch_size>1:
251
+ verts, faces, normals = [], [], []
252
+ for i in range(batch_size):
253
+ verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=0)
254
+ verts.append(verts_cur)
255
+ faces.append(faces_cur)
256
+ normals.append(normals_cur)
257
+ verts = np.stack(verts, axis = 0)
258
+ faces = np.stack(faces, axis = 0)
259
+ normals = np.stack(normals, axis = 0)
260
+ else:
261
+ try:
262
+ verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
263
+ except:
264
+ verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
265
+ if real_scale:
266
+ verts = verts / (s-1) # scale to range [0, 1]
267
+ else:
268
+ verts = verts / s # scale to range [0, 1)
269
+
270
+ if pytorchify:
271
+ device = psr_grid.device
272
+ verts = torch.Tensor(np.ascontiguousarray(verts)).to(device)
273
+ faces = torch.Tensor(np.ascontiguousarray(faces)).to(device)
274
+ normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)
275
+
276
+ return verts, faces, normals
277
+
278
+ def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
279
+ verts = verts.squeeze()
280
+ faces = faces.squeeze()
281
+ pix_to_face, w, mask = mesh_rasterization(verts, faces, pose, img_size)
282
+ if mask_gt is not None:
283
+ #! only evaluate within the intersection
284
+ mask = mask & mask_gt
285
+ # find 3D points intesected on the mesh
286
+ if True:
287
+ w_masked = w[mask]
288
+ f_p = faces[pix_to_face[mask]].long() # cooresponding faces for each pixel
289
+ # corresponding vertices for p_closest
290
+ v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
291
+
292
+ # calculate the intersection point of each pixel and the mesh
293
+ p_inters = w_masked[..., 0, None] * v_a + \
294
+ w_masked[..., 1, None] * v_b + \
295
+ w_masked[..., 2, None] * v_c
296
+ else:
297
+ # backproject ndc to world coordinates using z-buffer
298
+ W, H = img_size[1], img_size[0]
299
+ xy = uv.to(mask.device)[mask]
300
+ x_ndc = 1 - (2*xy[:, 0]) / (W - 1)
301
+ y_ndc = 1 - (2*xy[:, 1]) / (H - 1)
302
+ z = zbuf.squeeze().reshape(H * W)[mask]
303
+ xy_depth = torch.stack((x_ndc, y_ndc, z), dim=1)
304
+
305
+ p_inters = pose.unproject_points(xy_depth, world_coordinates=True)
306
+
307
+ # if there are outlier points, we should remove it
308
+ if (p_inters.max()>1) | (p_inters.min()<-1):
309
+ mask_bound = (p_inters>=-1) & (p_inters<=1)
310
+ mask_bound = (mask_bound.sum(dim=-1)==3)
311
+ mask[mask==True] = mask_bound
312
+ p_inters = p_inters[mask_bound]
313
+ print('!!!!!find outlier!')
314
+
315
+ return p_inters, mask, f_p, w_masked
316
+
317
+ def mesh_rasterization(verts, faces, pose, img_size):
318
+ '''
319
+ Use PyTorch3D to rasterize the mesh given a camera
320
+ '''
321
+ transformed_v = pose.transform_points(verts.detach()) # world -> ndc coordinate system
322
+ if isinstance(pose, PerspectiveCameras):
323
+ transformed_v[..., 2] = 1/transformed_v[..., 2]
324
+ # find p_closest on mesh of each pixel via rasterization
325
+ transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
326
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
327
+ transformed_mesh,
328
+ image_size=img_size,
329
+ blur_radius=0,
330
+ faces_per_pixel=1,
331
+ perspective_correct=False
332
+ )
333
+ pix_to_face = pix_to_face.reshape(1, -1) # B x reso x reso -> B x (reso x reso)
334
+ mask = pix_to_face.clone() != -1
335
+ mask = mask.squeeze()
336
+ pix_to_face = pix_to_face.squeeze()
337
+ w = bary_coords.reshape(-1, 3)
338
+
339
+ return pix_to_face, w, mask
340
+
341
+ def verts_on_largest_mesh(verts, faces):
342
+ '''
343
+ verts: Numpy array or Torch.Tensor (N, 3)
344
+ faces: Numpy array (N, 3)
345
+ '''
346
+ if torch.is_tensor(faces):
347
+ verts = verts.squeeze().detach().cpu().numpy()
348
+ faces = faces.squeeze().int().detach().cpu().numpy()
349
+
350
+ A = adjacency_matrix(faces)
351
+ num, conn_idx, conn_size = connected_components(A)
352
+ if num == 0:
353
+ v_large, f_large = verts, faces
354
+ else:
355
+ max_idx = conn_size.argmax() # find the index of the largest component
356
+ v_large = verts[conn_idx==max_idx] # keep points on the largest component
357
+
358
+ if True:
359
+ mesh_largest = trimesh.Trimesh(verts, faces)
360
+ connected_comp = mesh_largest.split(only_watertight=False)
361
+ mesh_largest = connected_comp[max_idx]
362
+ v_large, f_large = mesh_largest.vertices, mesh_largest.faces
363
+ v_large = v_large.astype(np.float32)
364
+ return v_large, f_large
365
+
366
+ def update_recursive(dict1, dict2):
367
+ ''' Update two config dictionaries recursively.
368
+
369
+ Args:
370
+ dict1 (dict): first dictionary to be updated
371
+ dict2 (dict): second dictionary which entries should be used
372
+
373
+ '''
374
+ for k, v in dict2.items():
375
+ if k not in dict1:
376
+ dict1[k] = dict()
377
+ if isinstance(v, dict):
378
+ update_recursive(dict1[k], v)
379
+ else:
380
+ dict1[k] = v
381
+
382
+ def scale2onet(p, scale=1.2):
383
+ '''
384
+ Scale the point cloud from SAP to ONet range
385
+ '''
386
+ return (p - 0.5) * scale
387
+
388
+ def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
389
+ if model is not None:
390
+ if schedule is not None:
391
+ optimizer = torch.optim.Adam([
392
+ {"params": model.parameters(),
393
+ "lr": schedule[0].get_learning_rate(epoch)},
394
+ {"params": inputs,
395
+ "lr": schedule[1].get_learning_rate(epoch)}])
396
+ elif 'lr' in cfg['train']:
397
+ optimizer = torch.optim.Adam([
398
+ {"params": model.parameters(),
399
+ "lr": float(cfg['train']['lr'])},
400
+ {"params": inputs,
401
+ "lr": float(cfg['train']['lr_pcl'])}])
402
+ else:
403
+ raise Exception('no known learning rate')
404
+ else:
405
+ if schedule is not None:
406
+ optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
407
+ else:
408
+ optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))
409
+
410
+ return optimizer
411
+
412
+
413
+ def is_url(url):
414
+ scheme = urllib.parse.urlparse(url).scheme
415
+ return scheme in ('http', 'https')
416
+
417
+ def load_url(url):
418
+ '''Load a module dictionary from url.
419
+
420
+ Args:
421
+ url (str): url to saved model
422
+ '''
423
+ print(url)
424
+ print('=> Loading checkpoint from url...')
425
+ state_dict = model_zoo.load_url(url, progress=True)
426
+
427
+ return state_dict
428
+
429
+
430
+ class GaussianSmoothing(nn.Module):
431
+ """
432
+ Apply gaussian smoothing on a
433
+ 1d, 2d or 3d tensor. Filtering is performed seperately for each channel
434
+ in the input using a depthwise convolution.
435
+ Arguments:
436
+ channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
437
+ kernel_size (int, sequence): Size of the gaussian kernel.
438
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
439
+ dim (int, optional): The number of dimensions of the data.
440
+ Default value is 2 (spatial).
441
+ """
442
+ def __init__(self, channels, kernel_size, sigma, dim=3):
443
+ super(GaussianSmoothing, self).__init__()
444
+ if isinstance(kernel_size, numbers.Number):
445
+ kernel_size = [kernel_size] * dim
446
+ if isinstance(sigma, numbers.Number):
447
+ sigma = [sigma] * dim
448
+
449
+ # The gaussian kernel is the product of the
450
+ # gaussian function of each dimension.
451
+ kernel = 1
452
+ meshgrids = torch.meshgrid(
453
+ [
454
+ torch.arange(size, dtype=torch.float32)
455
+ for size in kernel_size
456
+ ]
457
+ )
458
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
459
+ mean = (size - 1) / 2
460
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
461
+ torch.exp(-((mgrid - mean) / std) ** 2 / 2)
462
+
463
+ # Make sure sum of values in gaussian kernel equals 1.
464
+ kernel = kernel / torch.sum(kernel)
465
+
466
+ # Reshape to depthwise convolutional weight
467
+ kernel = kernel.view(1, 1, *kernel.size())
468
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
469
+
470
+ self.register_buffer('weight', kernel)
471
+ self.groups = channels
472
+
473
+ if dim == 1:
474
+ self.conv = F.conv1d
475
+ elif dim == 2:
476
+ self.conv = F.conv2d
477
+ elif dim == 3:
478
+ self.conv = F.conv3d
479
+ else:
480
+ raise RuntimeError(
481
+ 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
482
+ )
483
+
484
+ def forward(self, input):
485
+ """
486
+ Apply gaussian filter to input.
487
+ Arguments:
488
+ input (torch.Tensor): Input to apply gaussian filter on.
489
+ Returns:
490
+ filtered (torch.Tensor): Filtered output.
491
+ """
492
+ return self.conv(input, weight=self.weight, groups=self.groups)
493
+
494
+ # Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
495
+ def get_learning_rate_schedules(schedule_specs):
496
+
497
+ schedules = []
498
+
499
+ for key in schedule_specs.keys():
500
+ schedules.append(StepLearningRateSchedule(
501
+ schedule_specs[key]['initial'],
502
+ schedule_specs[key]["interval"],
503
+ schedule_specs[key]["factor"],
504
+ schedule_specs[key]["final"]))
505
+ return schedules
506
+
507
+ class LearningRateSchedule:
508
+ def get_learning_rate(self, epoch):
509
+ pass
510
+ class StepLearningRateSchedule(LearningRateSchedule):
511
+ def __init__(self, initial, interval, factor, final=1e-6):
512
+ self.initial = float(initial)
513
+ self.interval = interval
514
+ self.factor = factor
515
+ self.final = float(final)
516
+
517
+ def get_learning_rate(self, epoch):
518
+ lr = np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
519
+ if lr > self.final:
520
+ return lr
521
+ else:
522
+ return self.final
523
+
524
+ def adjust_learning_rate(lr_schedules, optimizer, epoch):
525
+ for i, param_group in enumerate(optimizer.param_groups):
526
+ param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (129 Bytes). View file
 
models/__pycache__/model.cpython-39.pyc ADDED
Binary file (5.56 kB). View file
 
models/model.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from statistics import mean
3
+ from typing import List, Tuple
4
+
5
+ import torch as th
6
+ import pytorch_lightning as pl
7
+ from jaxtyping import Float, Int
8
+ import numpy as np
9
+ from torch_geometric.nn.conv import GATv2Conv
10
+
11
+ from models.SAP.dpsr import DPSR
12
+ from models.SAP.model import PSR2Mesh
13
+
14
+ # Constants
15
+
16
+ th.manual_seed(0)
17
+ np.random.seed(0)
18
+
19
+ BATCH_SIZE = 1 # BS
20
+
21
+ IN_DIM = 1
22
+ OUT_DIM = 1
23
+ LATENT_DIM = 32
24
+
25
+ DROPOUT_PROB = 0.1
26
+ GRID_SIZE = 128
27
+
28
+ def generate_grid_edge_list(gs: int = 128):
29
+ grid_edge_list = []
30
+
31
+ for k in range(gs):
32
+ for j in range(gs):
33
+ for i in range(gs):
34
+ current_idx = i + gs*j + k*gs*gs
35
+ if (i - 1) >= 0:
36
+ grid_edge_list.append([current_idx, i-1 + gs*j + k*gs*gs])
37
+ if (i + 1) < gs:
38
+ grid_edge_list.append([current_idx, i+1 + gs*j + k*gs*gs])
39
+ if (j - 1) >= 0:
40
+ grid_edge_list.append([current_idx, i + gs*(j-1) + k*gs*gs])
41
+ if (j + 1) < gs:
42
+ grid_edge_list.append([current_idx, i + gs*(j+1) + k*gs*gs])
43
+ if (k - 1) >= 0:
44
+ grid_edge_list.append([current_idx, i + gs*j + (k-1)*gs*gs])
45
+ if (k + 1) < gs:
46
+ grid_edge_list.append([current_idx, i + gs*j + (k+1)*gs*gs])
47
+ return grid_edge_list
48
+
49
+ GRID_EDGE_LIST = generate_grid_edge_list(GRID_SIZE)
50
+ GRID_EDGE_LIST = th.tensor(GRID_EDGE_LIST, dtype=th.int)
51
+ GRID_EDGE_LIST = GRID_EDGE_LIST.T
52
+ # GRID_EDGE_LIST = GRID_EDGE_LIST.to(th.device("cuda"))
53
+ GRID_EDGE_LIST.requires_grad = False # Do not forget to delete it if train
54
+
55
+
56
+ class FormOptimizer(th.nn.Module):
57
+ def __init__(self) -> None:
58
+ super().__init__()
59
+
60
+ layers = []
61
+
62
+ self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
63
+ self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
64
+
65
+ self.actv = th.nn.Sigmoid()
66
+ self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)
67
+
68
+ def forward(self,
69
+ field: Float[th.Tensor, "GS GS GS"]) -> Float[th.Tensor, "GS GS GS"]:
70
+ """
71
+ Args:
72
+ field (Tensor [GS, GS, GS]): vertices and normals tensor.
73
+ """
74
+ vertex_features = field.clone()
75
+ vertex_features = vertex_features.reshape(GRID_SIZE*GRID_SIZE*GRID_SIZE, IN_DIM)
76
+
77
+ vertex_features = self.gconv1(x=vertex_features, edge_index=GRID_EDGE_LIST)
78
+ vertex_features = self.gconv2(x=vertex_features, edge_index=GRID_EDGE_LIST)
79
+ field_delta = self.head(self.actv(vertex_features))
80
+
81
+ field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
82
+ field_delta += field # field_delta carries the gradient
83
+ field_delta = th.clamp(field_delta, min=-0.5, max=0.5)
84
+
85
+ return field_delta
86
+
87
+ class Model(pl.LightningModule):
88
+ def __init__(self):
89
+ super().__init__()
90
+ self.form_optimizer = FormOptimizer()
91
+
92
+ self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
93
+ self.field2mesh = PSR2Mesh().apply
94
+
95
+ self.metric = th.nn.MSELoss()
96
+
97
+ self.val_losses = []
98
+ self.train_losses = []
99
+
100
+ def log_h5(self, points, normals):
101
+ dset = self.log_points_file.create_dataset(
102
+ name=str(self.h5_frame),
103
+ shape=points.shape,
104
+ dtype=np.float16,
105
+ compression="gzip")
106
+ dset[:] = points
107
+ dset = self.log_normals_file.create_dataset(
108
+ name=str(self.h5_frame),
109
+ shape=normals.shape,
110
+ dtype=np.float16,
111
+ compression="gzip")
112
+ dset[:] = normals
113
+ self.h5_frame += 1
114
+
115
+ def forward(self,
116
+ v: Float[th.Tensor, "BS N 3"],
117
+ n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"], # v - vertices
118
+ Int[th.Tensor, "2 E"], # f - faces
119
+ Float[th.Tensor, "BS N 3"], # n - vertices normals
120
+ Float[th.Tensor, "BS GR GR GR"]]: # field:
121
+ field = self.dpsr(v, n)
122
+ field = self.form_optimizer(field)
123
+ v, f, n = self.field2mesh(field)
124
+ return v, f, n, field
125
+
126
+ def training_step(self, batch, batch_idx) -> Float[th.Tensor, "1"]:
127
+ vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
128
+
129
+ mask = th.rand((vertices.shape[1], ), device=th.device("cuda")) < (random.random() / 2.0 + 0.5)
130
+ vertices = vertices[:, mask]
131
+ vertices_normals = vertices_normals[:, mask]
132
+
133
+ vr, fr, nr, field_r = model(vertices, vertices_normals)
134
+
135
+ loss = self.metric(field_r, field_gt)
136
+ train_per_step_loss = loss.item()
137
+ self.train_losses.append(train_per_step_loss)
138
+
139
+ return loss
140
+
141
+ def on_train_epoch_end(self):
142
+ mean_train_per_epoch_loss = mean(self.train_losses)
143
+ self.log("mean_train_per_epoch_loss", mean_train_per_epoch_loss, on_step=False, on_epoch=True)
144
+ self.train_losses = []
145
+
146
+ def validation_step(self, batch, batch_idx):
147
+ vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
148
+
149
+ vr, fr, nr, field_r = model(vertices, vertices_normals)
150
+
151
+ loss = self.metric(field_r, field_gt)
152
+ val_per_step_loss = loss.item()
153
+ self.val_losses.append(val_per_step_loss)
154
+ return loss
155
+
156
+ def on_validation_epoch_end(self):
157
+ mean_val_per_epoch_loss = mean(self.val_losses)
158
+ self.log("mean_val_per_epoch_loss", mean_val_per_epoch_loss, on_step=False, on_epoch=True)
159
+ self.val_losses = []
160
+
161
+ def configure_optimizers(self):
162
+ optimizer = th.optim.Adam(self.parameters(), lr=LR)
163
+ scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
164
+
165
+ return {
166
+ "optimizer": optimizer,
167
+ "lr_scheduler": {
168
+ "scheduler": scheduler,
169
+ "monitor": "mean_val_per_epoch_loss",
170
+ "interval": "epoch",
171
+ "frequency": 1,
172
+ # If set to `True`, will enforce that the value specified 'monitor'
173
+ # is available when the scheduler is updated, thus stopping
174
+ # training if not found. If set to `False`, it will only produce a warning
175
+ "strict": True,
176
+ # If using the `LearningRateMonitor` callback to monitor the
177
+ # learning rate progress, this keyword can be used to specify
178
+ # a custom logged name
179
+ "name": None,
180
+ }
181
+ }
requirements.txt ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.4
2
+ aiosignal==1.3.1
3
+ ansicon==1.89.0
4
+ anyio==3.6.2
5
+ arrow==1.2.3
6
+ asttokens==2.2.1
7
+ async-timeout==4.0.2
8
+ attrs==23.1.0
9
+ backcall==0.2.0
10
+ beautifulsoup4==4.12.2
11
+ blessed==1.20.0
12
+ blinker==1.6.2
13
+ certifi==2023.5.7
14
+ charset-normalizer==3.1.0
15
+ click==8.1.3
16
+ colorama==0.4.6
17
+ comm==0.1.3
18
+ ConfigArgParse==1.5.3
19
+ croniter==1.3.14
20
+ dash==2.9.3
21
+ dash-core-components==2.0.0
22
+ dash-html-components==2.0.0
23
+ dash-table==5.0.0
24
+ dateutils==0.6.12
25
+ debugpy==1.6.7
26
+ decorator==5.1.1
27
+ deepdiff==6.3.0
28
+ executing==1.2.0
29
+ fastapi==0.88.0
30
+ fastjsonschema==2.16.3
31
+ Flask==2.3.2
32
+ frozenlist==1.3.3
33
+ fsspec==2023.5.0
34
+ fvcore==0.1.5.post20221221
35
+ h11==0.14.0
36
+ idna==3.4
37
+ imageio==2.28.1
38
+ importlib-metadata==6.6.0
39
+ inquirer==3.1.3
40
+ iopath==0.1.10
41
+ ipykernel==6.23.1
42
+ ipython==8.13.2
43
+ ipywidgets==8.0.6
44
+ itsdangerous==2.1.2
45
+ jaxtyping==0.2.19
46
+ jedi==0.18.2
47
+ Jinja2==3.1.2
48
+ jinxed==1.2.0
49
+ joblib==1.2.0
50
+ jsonschema==4.17.3
51
+ jupyter_client==8.2.0
52
+ jupyter_core==5.3.0
53
+ jupyterlab-widgets==3.0.7
54
+ lazy_loader==0.2
55
+ libigl==2.4.1
56
+ lightning==2.0.2
57
+ lightning-cloud==0.5.36
58
+ lightning-utilities==0.8.0
59
+ markdown-it-py==2.2.0
60
+ MarkupSafe==2.1.2
61
+ matplotlib-inline==0.1.6
62
+ mdurl==0.1.2
63
+ multidict==6.0.4
64
+ nbformat==5.7.0
65
+ nest-asyncio==1.5.6
66
+ networkx==3.1
67
+ numpy==1.24.3
68
+ open3d==0.17.0
69
+ ordered-set==4.1.0
70
+ packaging==23.1
71
+ parso==0.8.3
72
+ pickleshare==0.7.5
73
+ Pillow==9.5.0
74
+ platformdirs==3.5.1
75
+ plotly==5.14.1
76
+ plyfile==0.9
77
+ portalocker==2.7.0
78
+ prompt-toolkit==3.0.38
79
+ psutil==5.9.5
80
+ pure-eval==0.2.2
81
+ pydantic==1.10.7
82
+ Pygments==2.15.1
83
+ PyJWT==2.7.0
84
+ pyparsing==3.0.9
85
+ pyrsistent==0.19.3
86
+ PySimpleGUI==4.60.4
87
+ python-dateutil==2.8.2
88
+ python-editor==1.0.4
89
+ python-multipart==0.0.6
90
+ pytorch-lightning==2.0.2
91
+ pytz==2023.3
92
+ PyWavelets==1.4.1
93
+ pywin32==306
94
+ PyYAML==6.0
95
+ pyzmq==25.0.2
96
+ readchar==4.0.5
97
+ requests==2.30.0
98
+ rich==13.3.5
99
+ scikit-image==0.20.0
100
+ scikit-learn==1.2.2
101
+ scipy==1.9.1
102
+ six==1.16.0
103
+ sniffio==1.3.0
104
+ soupsieve==2.4.1
105
+ stack-data==0.6.2
106
+ starlette==0.22.0
107
+ starsessions==1.3.0
108
+ tabulate==0.9.0
109
+ tenacity==8.2.2
110
+ termcolor==2.3.0
111
+ threadpoolctl==3.1.0
112
+ tifffile==2023.4.12
113
+ torch==1.13.1+cu116
114
+ torch-cluster==1.6.1+pt113cu116
115
+ torch-geometric==2.3.1
116
+ torch-scatter==2.1.1+pt113cu116
117
+ torch-sparse==0.6.17+pt113cu116
118
+ torch-spline-conv==1.2.2+pt113cu116
119
+ torchaudio==0.13.1
120
+ torchmetrics==0.11.4
121
+ torchvision==0.14.1+cu116
122
+ tornado==6.3.2
123
+ tqdm==4.65.0
124
+ traitlets==5.9.0
125
+ trimesh==3.21.6
126
+ typeguard==4.0.0
127
+ typing_extensions==4.5.0
128
+ urllib3==2.0.2
129
+ uvicorn==0.22.0
130
+ wcwidth==0.2.6
131
+ websocket-client==1.5.1
132
+ websockets==11.0.3
133
+ Werkzeug==2.3.4
134
+ widgetsnbextension==4.0.7
135
+ yacs==0.1.8
136
+ yarl==1.9.2
137
+ zipp==3.15.0
138
+ aiofiles==23.1.0
139
+ aiohttp==3.8.4
140
+ aiosignal==1.3.1
141
+ altair==5.0.0
142
+ ansicon==1.89.0
143
+ anyio==3.6.2
144
+ arrow==1.2.3
145
+ asttokens==2.2.1
146
+ async-timeout==4.0.2
147
+ attrs==23.1.0
148
+ backcall==0.2.0
149
+ beautifulsoup4==4.12.2
150
+ blessed==1.20.0
151
+ blinker==1.6.2
152
+ certifi==2023.5.7
153
+ charset-normalizer==3.1.0
154
+ click==8.1.3
155
+ colorama==0.4.6
156
+ comm==0.1.3
157
+ ConfigArgParse==1.5.3
158
+ contourpy==1.0.7
159
+ croniter==1.3.14
160
+ cycler==0.11.0
161
+ dash==2.9.3
162
+ dash-core-components==2.0.0
163
+ dash-html-components==2.0.0
164
+ dash-table==5.0.0
165
+ dateutils==0.6.12
166
+ debugpy==1.6.7
167
+ decorator==5.1.1
168
+ deepdiff==6.3.0
169
+ executing==1.2.0
170
+ fastapi==0.88.0
171
+ fastjsonschema==2.16.3
172
+ ffmpy==0.3.0
173
+ filelock==3.12.0
174
+ Flask==2.3.2
175
+ fonttools==4.39.4
176
+ frozenlist==1.3.3
177
+ fsspec==2023.5.0
178
+ fvcore==0.1.5.post20221221
179
+ gradio==3.30.0
180
+ gradio_client==0.2.5
181
+ h11==0.14.0
182
+ httpcore==0.17.0
183
+ httpx==0.24.0
184
+ huggingface-hub==0.14.1
185
+ idna==3.4
186
+ imageio==2.28.1
187
+ importlib-metadata==6.6.0
188
+ importlib-resources==5.12.0
189
+ inquirer==3.1.3
190
+ iopath==0.1.10
191
+ ipykernel==6.23.1
192
+ ipython==8.13.2
193
+ ipywidgets==8.0.6
194
+ itsdangerous==2.1.2
195
+ jaxtyping==0.2.19
196
+ jedi==0.18.2
197
+ Jinja2==3.1.2
198
+ jinxed==1.2.0
199
+ joblib==1.2.0
200
+ jsonschema==4.17.3
201
+ jupyter_client==8.2.0
202
+ jupyter_core==5.3.0
203
+ jupyterlab-widgets==3.0.7
204
+ kiwisolver==1.4.4
205
+ lazy_loader==0.2
206
+ libigl==2.4.1
207
+ lightning==2.0.2
208
+ lightning-cloud==0.5.36
209
+ lightning-utilities==0.8.0
210
+ linkify-it-py==2.0.2
211
+ markdown-it-py==2.2.0
212
+ MarkupSafe==2.1.2
213
+ matplotlib==3.7.1
214
+ matplotlib-inline==0.1.6
215
+ mdit-py-plugins==0.3.3
216
+ mdurl==0.1.2
217
+ multidict==6.0.4
218
+ nbformat==5.7.0
219
+ nest-asyncio==1.5.6
220
+ networkx==3.1
221
+ numpy==1.24.3
222
+ open3d==0.17.0
223
+ ordered-set==4.1.0
224
+ orjson==3.8.12
225
+ packaging==23.1
226
+ pandas==2.0.1
227
+ parso==0.8.3
228
+ pickleshare==0.7.5
229
+ Pillow==9.5.0
230
+ platformdirs==3.5.1
231
+ plotly==5.14.1
232
+ plyfile==0.9
233
+ portalocker==2.7.0
234
+ prompt-toolkit==3.0.38
235
+ psutil==5.9.5
236
+ pure-eval==0.2.2
237
+ pydantic==1.10.7
238
+ pydub==0.25.1
239
+ Pygments==2.15.1
240
+ PyJWT==2.7.0
241
+ pyparsing==3.0.9
242
+ pyrsistent==0.19.3
243
+ PySimpleGUI==4.60.4
244
+ python-dateutil==2.8.2
245
+ python-editor==1.0.4
246
+ python-multipart==0.0.6
247
+ pytorch-lightning==2.0.2
248
+ pytz==2023.3
249
+ PyWavelets==1.4.1
250
+ pywin32==306
251
+ PyYAML==6.0
252
+ pyzmq==25.0.2
253
+ readchar==4.0.5
254
+ requests==2.30.0
255
+ rich==13.3.5
256
+ scikit-image==0.20.0
257
+ scikit-learn==1.2.2
258
+ scipy==1.9.1
259
+ semantic-version==2.10.0
260
+ six==1.16.0
261
+ sniffio==1.3.0
262
+ soupsieve==2.4.1
263
+ stack-data==0.6.2
264
+ starlette==0.22.0
265
+ starsessions==1.3.0
266
+ tabulate==0.9.0
267
+ tenacity==8.2.2
268
+ termcolor==2.3.0
269
+ threadpoolctl==3.1.0
270
+ tifffile==2023.4.12
271
+ toolz==0.12.0
272
+ torch==1.13.1+cu116
273
+ torch-cluster==1.6.1+pt113cu116
274
+ torch-geometric==2.3.1
275
+ torch-scatter==2.1.1+pt113cu116
276
+ torch-sparse==0.6.17+pt113cu116
277
+ torch-spline-conv==1.2.2+pt113cu116
278
+ torchaudio==0.13.1
279
+ torchmetrics==0.11.4
280
+ torchvision==0.14.1+cu116
281
+ tornado==6.3.2
282
+ tqdm==4.65.0
283
+ traitlets==5.9.0
284
+ trimesh==3.21.6
285
+ typeguard==4.0.0
286
+ typing_extensions==4.5.0
287
+ tzdata==2023.3
288
+ uc-micro-py==1.0.2
289
+ urllib3==2.0.2
290
+ uvicorn==0.22.0
291
+ wcwidth==0.2.6
292
+ websocket-client==1.5.1
293
+ websockets==11.0.3
294
+ Werkzeug==2.3.4
295
+ widgetsnbextension==4.0.7
296
+ yacs==0.1.8
297
+ yarl==1.9.2
298
+ zipp==3.15.0