|
import os
import tempfile

# Third-party imports are kept in their original relative order.
import gradio as gr
import pytorch_lightning as pl
import torch as th
import open3d as o3d
import numpy as np
import trimesh as tm

from models.model import Model
|
|
|
# Build the network and restore the trained weights from the checkpoint.
# The path is resolved relative to this file (like the example meshes used by
# the demo) so the app works regardless of the directory it is launched from.
CHECKPOINT_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "checkpoints",
    "epoch=99-step=6000.ckpt",
)

model = Model()

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host; the inference code feeds the model CPU tensors.
ckpg = th.load(CHECKPOINT_PATH, map_location="cpu")

model.load_state_dict(ckpg["state_dict"])

# The model is only used for inference: switch layers such as dropout or
# batch-norm (if Model contains any) to their evaluation behavior.
model.eval()
|
|
|
|
|
def _to_numpy(x):
    """Return *x* as a NumPy array, moving torch tensors to the CPU first."""
    if isinstance(x, th.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def process_mesh(mesh_file_name):
    """Run the loaded model on an uploaded mesh and return the resulting OBJ path.

    Args:
        mesh_file_name: Path of the mesh file provided by the ``gr.Model3D``
            input component, or ``None`` when nothing was uploaded.

    Returns:
        Path of a newly created ``.obj`` file holding the mesh built from the
        vertices, faces and vertex normals returned by the model.

    Raises:
        gr.Error: If no mesh file was provided.
    """
    if mesh_file_name is None:
        # Surface a readable message in the UI instead of a trimesh traceback.
        raise gr.Error("Please upload a 3D mesh first.")

    mesh = tm.load_mesh(mesh_file_name)

    # Add a leading batch dimension of size 1 for the model call.
    v = th.tensor(mesh.vertices, dtype=th.float).unsqueeze(0)
    n = th.tensor(mesh.vertex_normals, dtype=th.float).unsqueeze(0)

    with th.no_grad():
        v, f, n, _ = model(v, n)

    # Drop the batch dimension and hand trimesh plain NumPy arrays, so this
    # keeps working even if the model's outputs live on a GPU.
    result = tm.Trimesh(
        vertices=_to_numpy(v.squeeze(0)),
        faces=_to_numpy(f.squeeze(0)),
        vertex_normals=_to_numpy(n.squeeze(0)),
    )

    # Write to a unique temporary file rather than a shared "./sample.obj":
    # concurrent requests would otherwise overwrite each other's output, and
    # the location would depend on the launch directory.
    fd, obj_path = tempfile.mkstemp(suffix=".obj")
    os.close(fd)  # only the unique name is needed; trimesh reopens the file
    result.export(obj_path)

    return obj_path
|
|
|
|
|
# Example meshes shipped next to this file. Each path component is passed to
# os.path.join separately so the paths are valid on every OS (a hard-coded
# "files\\name.obj" only works on Windows).
_EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "files")
_EXAMPLE_MESHES = [
    "bunny_n1_hi_50.obj",
    "child_n2_80.obj",
    "eight_n3_70.obj",
]

# Gradio UI: upload a mesh, run process_mesh on it, show the returned OBJ.
demo = gr.Interface(
    fn=process_mesh,
    inputs=gr.Model3D(),
    # clear_color is the RGBA background color of the 3D viewer.
    outputs=gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
    # One single-input example row per mesh file.
    examples=[[os.path.join(_EXAMPLES_DIR, name)] for name in _EXAMPLE_MESHES],
)
|
|
|
def main():
    """Start the Gradio server for the demo interface."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()