MVV's picture
Upload 19 files
583d0a8
raw
history blame
No virus
7.31 kB
import random
from statistics import mean
from typing import List, Tuple
import torch as th
import pytorch_lightning as pl
from jaxtyping import Float, Int
import numpy as np
from torch_geometric.nn.conv import GATv2Conv
from models.SAP.dpsr import DPSR
from models.SAP.model import PSR2Mesh
# Constants
# Seed every RNG used in this module for reproducibility: torch, numpy, and the
# stdlib ``random`` module (training_step draws its point keep-ratio from it).
th.manual_seed(0)
np.random.seed(0)
random.seed(0)
BATCH_SIZE = 1  # BS: FormOptimizer reshapes its output to this leading batch size
IN_DIM = 1  # feature channels per grid vertex fed to the first graph layer
OUT_DIM = 1  # feature channels per grid vertex produced by the linear head
LATENT_DIM = 32  # hidden channels of the GATv2Conv layers
DROPOUT_PROB = 0.1  # attention dropout of the GATv2Conv layers
GRID_SIZE = 128  # voxels per axis of the cubic field grid
# Learning rate handed to Adam in Model.configure_optimizers, which reads this
# name at call time; 1e-3 is torch.optim.Adam's own default learning rate.
LR = 1e-3
def generate_grid_edge_list(gs: int = 128):
    """Build the directed 6-neighbour edge list of a cubic ``gs**3`` voxel grid.

    Voxel ``(i, j, k)`` has flat index ``i + gs*j + gs*gs*k``. Voxels are
    visited in increasing flat-index order, and for each one a
    ``[current, neighbour]`` pair is emitted for every in-bounds face
    neighbour, in the order ``i-1, i+1, j-1, j+1, k-1, k+1``. Both directions
    of each adjacency therefore appear.

    The construction is vectorized with NumPy. The edge order and the
    plain-Python return type are identical to a straightforward triple loop,
    but it avoids millions of interpreted iterations at import time.

    Args:
        gs: Number of voxels along each axis.

    Returns:
        List of ``[src, dst]`` pairs of Python ints (empty when ``gs <= 1``).
    """
    if gs <= 0:
        return []
    plane = gs * gs
    current = np.arange(gs * plane)
    # Recover the (i, j, k) coordinates of every flat index.
    i = current % gs
    j = (current // gs) % gs
    k = current // plane
    # Flat-index offsets and in-bounds masks, in the per-voxel emission order.
    offsets = np.array([-1, 1, -gs, gs, -plane, plane])
    valid = np.stack(
        [i > 0, i < gs - 1, j > 0, j < gs - 1, k > 0, k < gs - 1],
        axis=1,
    )  # shape (gs**3, 6)
    dst = current[:, None] + offsets
    src = np.broadcast_to(current[:, None], dst.shape)
    # Boolean-mask selection walks rows first, then columns (row-major), which
    # reproduces "for each voxel, for each neighbour" ordering.
    return np.stack([src[valid], dst[valid]], axis=1).tolist()
# Directed edge index of the GRID_SIZE**3 voxel grid in PyG's COO layout:
# shape [2, num_edges], row 0 = current voxel, row 1 = its face neighbour.
# PyG message-passing layers expect edge_index as int64 (torch.long), and
# ``.contiguous()`` materialises the transposed view as a dense tensor.
GRID_EDGE_LIST = th.tensor(generate_grid_edge_list(GRID_SIZE), dtype=th.long).T.contiguous()
# The tensor is created on the CPU; it must live on the same device as the node
# features when it is passed to the graph layers.
# Integer tensors can never require gradients, so this only states the intent.
GRID_EDGE_LIST.requires_grad = False
class FormOptimizer(th.nn.Module):
    """Refine a scalar field on the GRID_SIZE**3 voxel grid with two GATv2 layers.

    Every voxel is a graph node with IN_DIM features, connected to its face
    neighbours through GRID_EDGE_LIST. The network predicts a per-voxel delta,
    adds it to the input field, and clamps the result to [-0.5, 0.5].
    """

    def __init__(self) -> None:
        super().__init__()
        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.actv = th.nn.Sigmoid()
        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)
        # The module-level edge index is a CPU tensor. Registering it as a
        # buffer makes ``.to(device)`` (and Lightning's device placement) move
        # it with the parameters. ``persistent=False`` keeps it out of
        # state_dict, so existing checkpoints stay compatible. PyG expects an
        # int64, contiguous edge_index.
        self.register_buffer("edge_index", GRID_EDGE_LIST.long().contiguous(), persistent=False)

    def forward(self,
                field: Float[th.Tensor, "GS GS GS"]) -> Float[th.Tensor, "BS GS GS GS"]:
        """Predict a residual update of ``field`` and return the clamped result.

        Args:
            field: Scalar field on the voxel grid. It must contain exactly
                GRID_SIZE**3 * IN_DIM elements (e.g. shape (GS, GS, GS) or
                (1, GS, GS, GS)), because it is flattened to one node per voxel
                in row-major order.

        Returns:
            ``field`` plus the predicted delta, shaped
            ``(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)`` and clamped to
            [-0.5, 0.5].
        """
        edge_index = self.edge_index
        # reshape never mutates ``field``, so no defensive clone is needed.
        vertex_features = field.reshape(GRID_SIZE * GRID_SIZE * GRID_SIZE, IN_DIM)
        vertex_features = self.gconv1(x=vertex_features, edge_index=edge_index)
        vertex_features = self.gconv2(x=vertex_features, edge_index=edge_index)
        field_delta = self.head(self.actv(vertex_features))
        field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
        # Residual connection: the network only learns a correction to the
        # input field (field_delta carries the gradient).
        field_delta += field
        return th.clamp(field_delta, min=-0.5, max=0.5)
class Model(pl.LightningModule):
    """Points + normals -> DPSR grid field -> FormOptimizer refinement -> mesh.

    Training and validation minimise the MSE between the refined field and the
    ground-truth field ``field_gt`` taken from each batch.
    """

    def __init__(self, lr=None):
        """
        Args:
            lr: Adam learning rate. When ``None`` (the default), the
                module-level ``LR`` constant is read in configure_optimizers.
        """
        super().__init__()
        self.form_optimizer = FormOptimizer()
        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
        self.field2mesh = PSR2Mesh().apply
        self.metric = th.nn.MSELoss()
        self.lr = lr
        # Per-step losses accumulated during an epoch, averaged and cleared
        # in the matching on_*_epoch_end hook.
        self.val_losses = []
        self.train_losses = []

    def log_h5(self, points, normals):
        """Append ``points`` and ``normals`` as gzip-compressed float16 datasets.

        Both datasets are named after the current ``self.h5_frame`` counter,
        which is then incremented.

        NOTE(review): ``self.log_points_file``, ``self.log_normals_file`` and
        ``self.h5_frame`` are not set in ``__init__``. Presumably the caller
        assigns them (h5py File-like objects and an int) before use; confirm.
        """
        dset = self.log_points_file.create_dataset(
            name=str(self.h5_frame),
            shape=points.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = points
        dset = self.log_normals_file.create_dataset(
            name=str(self.h5_frame),
            shape=normals.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = normals
        self.h5_frame += 1

    def forward(self,
                v: Float[th.Tensor, "BS N 3"],
                n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"],  # v - vertices
                                                        Int[th.Tensor, "2 E"],  # f - faces
                                                        Float[th.Tensor, "BS N 3"],  # n - vertices normals
                                                        Float[th.Tensor, "BS GR GR GR"]]:  # field:
        """Reconstruct a mesh from points ``v`` and their normals ``n``.

        ``self.dpsr`` turns the oriented points into a grid field,
        ``self.form_optimizer`` refines it, and ``self.field2mesh`` extracts
        vertices, faces and vertex normals from the refined field.

        Returns:
            ``(v, f, n, field)``: the mesh produced by ``field2mesh`` plus the
            refined field.
        """
        field = self.dpsr(v, n)
        field = self.form_optimizer(field)
        v, f, n = self.field2mesh(field)
        return v, f, n, field

    def training_step(self, batch, batch_idx) -> Float[th.Tensor, "1"]:
        """Compute the field MSE on a randomly subsampled input point cloud."""
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
        # Augmentation: keep a random fraction in [0.5, 1.0) of the input points.
        # The mask is created on the same device as the batch rather than a
        # hard-coded CUDA device.
        keep_ratio = random.random() / 2.0 + 0.5
        mask = th.rand((vertices.shape[1], ), device=vertices.device) < keep_ratio
        vertices = vertices[:, mask]
        vertices_normals = vertices_normals[:, mask]
        # Call this module, not a global ``model`` that may not exist or may be
        # a different instance.
        vr, fr, nr, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        self.train_losses.append(loss.item())
        return loss

    def on_train_epoch_end(self):
        # statistics.mean raises on an empty list (an epoch without training
        # steps), so only log when at least one step was recorded.
        if self.train_losses:
            mean_train_per_epoch_loss = mean(self.train_losses)
            self.log("mean_train_per_epoch_loss", mean_train_per_epoch_loss, on_step=False, on_epoch=True)
        self.train_losses = []

    def validation_step(self, batch, batch_idx):
        """Compute the field MSE on the full (unmasked) input point cloud."""
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
        vr, fr, nr, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        self.val_losses.append(loss.item())
        return loss

    def on_validation_epoch_end(self):
        # Same empty-list guard as on_train_epoch_end.
        if self.val_losses:
            mean_val_per_epoch_loss = mean(self.val_losses)
            self.log("mean_val_per_epoch_loss", mean_val_per_epoch_loss, on_step=False, on_epoch=True)
        self.val_losses = []

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau, halving the LR after 5 stagnant val epochs."""
        lr = LR if self.lr is None else self.lr
        optimizer = th.optim.Adam(self.parameters(), lr=lr)
        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "mean_val_per_epoch_loss",
                "interval": "epoch",
                "frequency": 1,
                # If set to `True`, will enforce that the value specified 'monitor'
                # is available when the scheduler is updated, thus stopping
                # training if not found. If set to `False`, it will only produce a warning
                "strict": True,
                # If using the `LearningRateMonitor` callback to monitor the
                # learning rate progress, this keyword can be used to specify
                # a custom logged name
                "name": None,
            }
        }