ppsurf / app.py
perler's picture
maybe problem by importing gradio twice
5cfff61
raw
history blame
6.88 kB
#!/usr/bin/env python
from __future__ import annotations
import os
import datetime
import pathlib
import shlex
import subprocess
import sys
from typing import Generator, Optional
import trimesh
import spaces
import gradio as gr
# from model import Model
sys.path.append('TEXTurePaper')
from src.configs.train_config import GuideConfig, LogConfig, TrainConfig
from src.training.trainer import TEXTure
class Model:
    """Runs TEXTure texturing on an uploaded mesh and streams results to the UI."""

    def __init__(self):
        # Upper bound on the number of face records accepted from an input .obj.
        self.max_num_faces = 100000

    def load_config(self, shape_path: str, text: str, seed: int,
                    guidance_scale: float) -> TrainConfig:
        """Build the TEXTure ``TrainConfig`` for a single run.

        Args:
            shape_path: Path to the input ``.obj`` mesh.
            text: Text prompt. ``', {} view'`` is appended; the ``{}`` is left
                unformatted here (presumably filled with a view name inside
                TEXTure -- verify).
            seed: Stored as ``config.optim.seed``.
            guidance_scale: Stored as ``config.guide.guidance_scale``.

        Returns:
            The populated config; its experiment name is a timestamp from
            ``gen_exp_name``.
        """
        text += ', {} view'
        log = LogConfig(exp_name=self.gen_exp_name())
        guide = GuideConfig(text=text)
        guide.background_img = 'TEXTurePaper/textures/brick_wall.png'
        # Placeholder shape; replaced with the user's mesh right below.
        guide.shape_path = 'TEXTurePaper/shapes/spot_triangulated.obj'
        config = TrainConfig(log=log, guide=guide)
        config.guide.shape_path = shape_path
        config.optim.seed = seed
        config.guide.guidance_scale = guidance_scale
        return config

    def gen_exp_name(self) -> str:
        """Return the current local time as ``YYYY-MM-DD-HH-MM-SS``."""
        now = datetime.datetime.now()
        return now.strftime('%Y-%m-%d-%H-%M-%S')

    def check_num_faces(self, path: str) -> bool:
        """Return True if the OBJ file at ``path`` has <= ``max_num_faces`` faces.

        Only ``f`` records (``f`` followed by a space or tab) are counted. The
        file is streamed line by line, and scanning stops as soon as the limit
        is exceeded.
        """
        num_faces = 0
        # errors='replace': non-UTF-8 bytes (e.g. in comments) must not crash
        # the check; only the ASCII record keyword matters here.
        with open(path, encoding='utf-8', errors='replace') as f:
            for line in f:
                if line.startswith(('f ', 'f\t')):
                    num_faces += 1
                    if num_faces > self.max_num_faces:
                        return False
        return True

    def zip_results(self, exp_dir: pathlib.Path) -> str:
        """Zip ``exp_dir / 'mesh'`` into ``<exp_dir.name>.zip`` in the CWD.

        Returns:
            The (relative) path of the created archive.

        Raises:
            subprocess.CalledProcessError: If the ``zip`` command fails.
        """
        mesh_dir = exp_dir / 'mesh'
        out_path = f'{exp_dir.name}.zip'
        # Argument list instead of a split command string, so paths that
        # contain spaces or shell metacharacters are passed through intact.
        subprocess.run(['zip', '-r', out_path, str(mesh_dir)], check=True)
        return out_path

    def _export_mesh(self, trainer: TEXTure) -> pathlib.Path:
        """Export the textured mesh under ``<exp>/mesh``; return the ``.glb`` path."""
        trainer.mesh_model.change_default_to_median()
        save_dir = trainer.exp_path / 'mesh'
        save_dir.mkdir(exist_ok=True, parents=True)
        trainer.mesh_model.export_mesh(save_dir)
        # export_mesh is expected to write mesh.obj; re-export it as GLB so
        # gr.Model3D can display it.
        mesh = trimesh.load(save_dir / 'mesh.obj')
        mesh_path = save_dir / 'mesh.glb'
        mesh.export(mesh_path, file_type='glb')
        return mesh_path

    def run(
        self, shape_path: str, text: str, seed: int, guidance_scale: float
    ) -> Generator[tuple[list[str], Optional[str], Optional[str], str], None,
                   None]:
        """Texture ``shape_path`` with TEXTure, streaming progress to the UI.

        Yields ``(viewpoint_image_paths, glb_path, zip_path, progress_text)``:
        once per painted viewpoint (with ``None`` for both result paths), then
        a final time with the exported ``.glb`` and zip archive.

        Raises:
            gr.Error: If no file was uploaded, the file is not ``.obj``, or it
                has more than ``self.max_num_faces`` faces.
        """
        # gr.Model3D passes None when nothing has been uploaded.
        if not shape_path:
            raise gr.Error('Please upload an .obj file.')
        if not shape_path.endswith('.obj'):
            raise gr.Error('The input file is not .obj file.')
        if not self.check_num_faces(shape_path):
            raise gr.Error(
                f'The number of faces is over {self.max_num_faces:,}.')

        config = self.load_config(shape_path, text, seed, guidance_scale)
        trainer = TEXTure(config)
        trainer.mesh_model.train()

        # Initialised up front so the final yield is valid even when the
        # train dataloader produces no viewpoints.
        sample_image_paths: list[str] = []
        train_loader = trainer.dataloaders['train']
        total_steps = len(train_loader)
        sample_image_dir = config.log.exp_dir / 'vis' / 'eval'
        for step, data in enumerate(train_loader, start=1):
            trainer.paint_step += 1
            trainer.paint_viewpoint(data)
            trainer.evaluate(trainer.dataloaders['val'],
                             trainer.eval_renders_path)
            # evaluate() may switch the model out of training mode; restore.
            trainer.mesh_model.train()
            sample_image_paths = [
                path.as_posix() for path in sorted(
                    sample_image_dir.glob(
                        f'step_{trainer.paint_step:05d}_*.jpg'))
            ]
            yield sample_image_paths, None, None, f'{step}/{total_steps}'

        mesh_path = self._export_mesh(trainer)
        zip_path = self.zip_results(config.log.exp_dir)
        yield sample_image_paths, mesh_path.as_posix(), zip_path, 'Done!'
def main():
    """Build the TEXTure Gradio demo and launch it behind a request queue."""
    DESCRIPTION = '''# [TEXTure](https://github.com/TEXTurePaper/TEXTurePaper)
- This demo only accepts as input `.obj` files with less than 100,000 faces.
- Inference takes about 10 minutes on a T4 GPU.
'''
    if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
        DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'

    model = Model()
    # Request a GPU only for each texturing job. Decorating main() instead
    # would wrap UI construction and the long-lived server launch in a single
    # GPU allocation.
    # NOTE(review): spaces.GPU uses its default time slot here; a run is
    # described above as ~10 min on a T4, so an explicit `duration=` may be
    # needed -- confirm against the Space's hardware limits.
    run_on_gpu = spaces.GPU(model.run)

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column():
                input_shape = gr.Model3D(label='Input 3D mesh')
                text = gr.Text(label='Text')
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=100000,
                                 value=3,
                                 step=1)
                guidance_scale = gr.Slider(label='Guidance scale',
                                           minimum=0,
                                           maximum=50,
                                           value=7.5,
                                           step=0.1)
                run_button = gr.Button('Run')
            with gr.Column():
                progress_text = gr.Text(label='Progress')
                with gr.Tabs():
                    with gr.TabItem(label='Images from each viewpoint'):
                        viewpoint_images = gr.Gallery(show_label=False,
                                                      columns=4)
                    with gr.TabItem(label='Result 3D model'):
                        result_3d_model = gr.Model3D(show_label=False)
                    with gr.TabItem(label='Output mesh file'):
                        output_file = gr.File(show_label=False)
        with gr.Row():
            examples = [
                ['shapes/dragon1.obj', 'a photo of a dragon', 0, 7.5],
                ['shapes/dragon2.obj', 'a photo of a dragon', 0, 7.5],
                ['shapes/eagle.obj', 'a photo of an eagle', 0, 7.5],
                ['shapes/napoleon.obj', 'a photo of Napoleon Bonaparte', 3, 7.5],
                ['shapes/nascar.obj', 'A next gen nascar', 2, 10],
            ]
            # No fn is given and caching is off, so examples only fill inputs.
            gr.Examples(examples=examples,
                        inputs=[
                            input_shape,
                            text,
                            seed,
                            guidance_scale,
                        ],
                        outputs=[
                            result_3d_model,
                            output_file,
                        ],
                        cache_examples=False)

        # model.run is a generator, so each yield streams into these outputs.
        run_button.click(fn=run_on_gpu,
                         inputs=[
                             input_shape,
                             text,
                             seed,
                             guidance_scale,
                         ],
                         outputs=[
                             viewpoint_images,
                             result_3d_model,
                             output_file,
                             progress_text,
                         ])

    demo.queue(max_size=5).launch(debug=True)
if __name__ == '__main__':
    # Script entry point: build the demo UI and start the Gradio server.
    main()