# Spaces: Runtime error
#!/usr/bin/env python | |
from __future__ import annotations | |
import os | |
import datetime | |
import pathlib | |
import shlex | |
import subprocess | |
import sys | |
from typing import Generator, Optional | |
import trimesh | |
import spaces | |
import gradio as gr | |
# from model import Model | |
sys.path.append('TEXTurePaper') | |
from src.configs.train_config import GuideConfig, LogConfig, TrainConfig | |
from src.training.trainer import TEXTure | |
class Model:
    """Gradio-facing wrapper that textures an uploaded ``.obj`` mesh with TEXTure.

    The class builds a ``TrainConfig`` for each request, drives the TEXTure
    trainer one viewpoint at a time, and exports the painted mesh as
    ``.obj``/``.glb`` plus a zip archive of the mesh directory.
    """

    def __init__(self):
        # Inclusive upper bound on the number of face lines accepted in an
        # input .obj file (see check_num_faces).
        self.max_num_faces = 100000

    def load_config(self, shape_path: str, text: str, seed: int,
                    guidance_scale: float) -> TrainConfig:
        """Build the training configuration for a single run.

        Args:
            shape_path: Path of the mesh to texture; stored on
                ``config.guide.shape_path``.
            text: User prompt. ``', {} view'`` is appended to it before it is
                handed to ``GuideConfig``.
            seed: Stored on ``config.optim.seed``.
            guidance_scale: Stored on ``config.guide.guidance_scale``.

        Returns:
            The populated ``TrainConfig``.
        """
        # The '{}' is left unformatted here — presumably TEXTure fills in the
        # view direction later; confirm against the trainer.
        text += ', {} view'

        log = LogConfig(exp_name=self.gen_exp_name())
        guide = GuideConfig(text=text)
        guide.background_img = 'TEXTurePaper/textures/brick_wall.png'
        # Placeholder shape; config.guide.shape_path is set to the caller's
        # mesh right after the TrainConfig is created.
        guide.shape_path = 'TEXTurePaper/shapes/spot_triangulated.obj'
        config = TrainConfig(log=log, guide=guide)

        config.guide.shape_path = shape_path
        config.optim.seed = seed
        config.guide.guidance_scale = guidance_scale
        return config

    def gen_exp_name(self) -> str:
        """Return an experiment name derived from the current local time.

        The format is ``YYYY-MM-DD-HH-MM-SS`` (second resolution).
        """
        now = datetime.datetime.now()
        return now.strftime('%Y-%m-%d-%H-%M-%S')

    def check_num_faces(self, path: str) -> bool:
        """Return True if the file at ``path`` is within the face limit.

        Every line starting with ``'f'`` is counted as one face; the file is
        accepted when that count is ``<= self.max_num_faces``.
        """
        # Stream the file instead of materialising every line, and replace
        # undecodable bytes so a stray non-text byte in an uploaded file
        # cannot abort the check with UnicodeDecodeError.
        with open(path, errors='replace') as f:
            num_faces = sum(1 for line in f if line.startswith('f'))
        return num_faces <= self.max_num_faces

    def zip_results(self, exp_dir: pathlib.Path) -> str:
        """Zip ``exp_dir / 'mesh'`` into ``<exp_dir.name>.zip`` and return its path.

        The archive path is relative, so it is created in the current working
        directory. The exit status of ``zip`` is not checked.
        """
        mesh_dir = exp_dir / 'mesh'
        out_path = f'{exp_dir.name}.zip'
        # Pass an argv list rather than formatting and re-splitting a command
        # string, so paths containing whitespace or quotes stay intact.
        subprocess.run(['zip', '-r', out_path, str(mesh_dir)])
        return out_path

    def run(
        self, shape_path: Optional[str], text: str, seed: int,
        guidance_scale: float
    ) -> Generator[tuple[list[str], Optional[str], Optional[str], str], None,
                   None]:
        """Texture the mesh at ``shape_path`` and stream progress to the UI.

        Args:
            shape_path: Path of the uploaded mesh; ``None`` or empty when the
                user has not provided one.
            text: Text prompt describing the texture.
            seed: Seed forwarded to the training config.
            guidance_scale: Guidance scale forwarded to the training config.

        Yields:
            ``(image_paths, glb_path, zip_path, progress)`` tuples. During
            painting, ``glb_path`` and ``zip_path`` are ``None`` and
            ``progress`` is ``'<step>/<total_steps>'``. The final tuple carries
            the exported ``.glb`` path, the zip path and ``'Done!'``.

        Raises:
            gr.Error: If no file was provided, the file name does not end in
                ``.obj``, or the mesh exceeds ``self.max_num_faces`` faces.
        """
        # Without this guard a missing upload fails with AttributeError on
        # `None.endswith` instead of a readable message in the UI.
        if not shape_path:
            raise gr.Error('Please upload an .obj file.')
        if not shape_path.endswith('.obj'):
            raise gr.Error('The input file is not .obj file.')
        if not self.check_num_faces(shape_path):
            raise gr.Error(
                f'The number of faces is over {self.max_num_faces:,}.')

        config = self.load_config(shape_path, text, seed, guidance_scale)
        trainer = TEXTure(config)

        trainer.mesh_model.train()

        # Pre-bind so the final yield is well-defined even if the training
        # dataloader yields no viewpoints.
        sample_image_paths: list[str] = []

        total_steps = len(trainer.dataloaders['train'])
        for step, data in enumerate(trainer.dataloaders['train'], start=1):
            trainer.paint_step += 1
            trainer.paint_viewpoint(data)
            trainer.evaluate(trainer.dataloaders['val'],
                             trainer.eval_renders_path)
            # Re-enter training mode after each evaluation pass (presumably
            # evaluate() switches the mesh model to eval mode — confirm).
            trainer.mesh_model.train()

            # Collect the renders written for the step that just finished.
            sample_image_dir = config.log.exp_dir / 'vis' / 'eval'
            sample_image_paths = [
                path.as_posix() for path in sorted(
                    sample_image_dir.glob(
                        f'step_{trainer.paint_step:05d}_*.jpg'))
            ]
            yield sample_image_paths, None, None, f'{step}/{total_steps}'

        trainer.mesh_model.change_default_to_median()

        save_dir = trainer.exp_path / 'mesh'
        save_dir.mkdir(exist_ok=True, parents=True)
        trainer.mesh_model.export_mesh(save_dir)
        # Convert the exported .obj to .glb for the Model3D output component.
        model_path = save_dir / 'mesh.obj'
        mesh = trimesh.load(model_path)
        mesh_path = save_dir / 'mesh.glb'
        mesh.export(mesh_path, file_type='glb')

        zip_path = self.zip_results(config.log.exp_dir)
        yield sample_image_paths, mesh_path.as_posix(), zip_path, 'Done!'
def main():
    """Assemble the Gradio Blocks UI for the TEXTure demo and launch it."""
    description = '''# [TEXTure](https://github.com/TEXTurePaper/TEXTurePaper)
- This demo only accepts as input `.obj` files with less than 100,000 faces.
- Inference takes about 10 minutes on a T4 GPU.
'''
    # When SPACE_ID is set in the environment, append a "duplicate this
    # space" badge that links to that space.
    space_id = os.getenv('SPACE_ID')
    if space_id is not None:
        description = description + f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{space_id}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'

    texture_model = Model()

    # Example rows: mesh path, prompt, seed, guidance scale.
    example_rows = [
        ['shapes/dragon1.obj', 'a photo of a dragon', 0, 7.5],
        ['shapes/dragon2.obj', 'a photo of a dragon', 0, 7.5],
        ['shapes/eagle.obj', 'a photo of an eagle', 0, 7.5],
        ['shapes/napoleon.obj', 'a photo of Napoleon Bonaparte', 3, 7.5],
        ['shapes/nascar.obj', 'A next gen nascar', 2, 10],
    ]

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(description)

        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                input_shape = gr.Model3D(label='Input 3D mesh')
                text = gr.Text(label='Text')
                seed = gr.Slider(
                    label='Seed', minimum=0, maximum=100000, value=3, step=1)
                guidance_scale = gr.Slider(
                    label='Guidance scale',
                    minimum=0,
                    maximum=50,
                    value=7.5,
                    step=0.1)
                run_button = gr.Button('Run')

            # Right column: progress text and tabbed results.
            with gr.Column():
                progress_text = gr.Text(label='Progress')
                with gr.Tabs():
                    with gr.TabItem(label='Images from each viewpoint'):
                        viewpoint_images = gr.Gallery(show_label=False,
                                                      columns=4)
                    with gr.TabItem(label='Result 3D model'):
                        result_3d_model = gr.Model3D(show_label=False)
                    with gr.TabItem(label='Output mesh file'):
                        output_file = gr.File(show_label=False)

        # The same four input components feed both the examples table and
        # the Run button.
        input_components = [input_shape, text, seed, guidance_scale]
        example_outputs = [result_3d_model, output_file]
        run_outputs = [
            viewpoint_images,
            result_3d_model,
            output_file,
            progress_text,
        ]

        with gr.Row():
            gr.Examples(examples=example_rows,
                        inputs=input_components,
                        outputs=example_outputs,
                        cache_examples=False)

        run_button.click(fn=texture_model.run,
                         inputs=input_components,
                         outputs=run_outputs)

    # Queue holds at most five pending requests.
    queued_demo = demo.queue(max_size=5)
    queued_demo.launch(debug=True)
# Launch the demo only when this file is executed as a script, so importing
# the module does not start the server.
if __name__ == '__main__':
    main()