perler commited on
Commit
5cfff61
·
1 Parent(s): c647615

maybe problem by importing gradio twice

Browse files
Files changed (1) hide show
  1. app.py +91 -1
app.py CHANGED
@@ -3,11 +3,101 @@
3
  from __future__ import annotations
4
 
5
  import os
 
 
 
 
 
 
6
 
 
7
  import spaces
8
  import gradio as gr
9
 
10
- from model import Model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  @spaces.GPU
 
3
  from __future__ import annotations
4
 
5
  import os
6
+ import datetime
7
+ import pathlib
8
+ import shlex
9
+ import subprocess
10
+ import sys
11
+ from typing import Generator, Optional
12
 
13
+ import trimesh
14
  import spaces
15
  import gradio as gr
16
 
17
+ # from model import Model
18
+
19
+ sys.path.append('TEXTurePaper')
20
+
21
+ from src.configs.train_config import GuideConfig, LogConfig, TrainConfig
22
+ from src.training.trainer import TEXTure
23
+
24
+
25
+ class Model:
26
+ def __init__(self):
27
+ self.max_num_faces = 100000
28
+
29
+ def load_config(self, shape_path: str, text: str, seed: int,
30
+ guidance_scale: float) -> TrainConfig:
31
+ text += ', {} view'
32
+
33
+ log = LogConfig(exp_name=self.gen_exp_name())
34
+ guide = GuideConfig(text=text)
35
+ guide.background_img = 'TEXTurePaper/textures/brick_wall.png'
36
+ guide.shape_path = 'TEXTurePaper/shapes/spot_triangulated.obj'
37
+ config = TrainConfig(log=log, guide=guide)
38
+
39
+ config.guide.shape_path = shape_path
40
+ config.optim.seed = seed
41
+ config.guide.guidance_scale = guidance_scale
42
+ return config
43
+
44
+ def gen_exp_name(self) -> str:
45
+ now = datetime.datetime.now()
46
+ return now.strftime('%Y-%m-%d-%H-%M-%S')
47
+
48
+ def check_num_faces(self, path: str) -> bool:
49
+ with open(path) as f:
50
+ lines = [line for line in f.readlines() if line.startswith('f')]
51
+ return len(lines) <= self.max_num_faces
52
+
53
+ def zip_results(self, exp_dir: pathlib.Path) -> str:
54
+ mesh_dir = exp_dir / 'mesh'
55
+ out_path = f'{exp_dir.name}.zip'
56
+ subprocess.run(shlex.split(f'zip -r {out_path} {mesh_dir}'))
57
+ return out_path
58
+
59
+ def run(
60
+ self, shape_path: str, text: str, seed: int, guidance_scale: float
61
+ ) -> Generator[tuple[list[str], Optional[str], Optional[str], str], None,
62
+ None]:
63
+ if not shape_path.endswith('.obj'):
64
+ raise gr.Error('The input file is not .obj file.')
65
+ if not self.check_num_faces(shape_path):
66
+ raise gr.Error('The number of faces is over 100,000.')
67
+
68
+ config = self.load_config(shape_path, text, seed, guidance_scale)
69
+ trainer = TEXTure(config)
70
+
71
+ trainer.mesh_model.train()
72
+
73
+ total_steps = len(trainer.dataloaders['train'])
74
+ for step, data in enumerate(trainer.dataloaders['train'], start=1):
75
+ trainer.paint_step += 1
76
+ trainer.paint_viewpoint(data)
77
+ trainer.evaluate(trainer.dataloaders['val'],
78
+ trainer.eval_renders_path)
79
+ trainer.mesh_model.train()
80
+
81
+ sample_image_dir = config.log.exp_dir / 'vis' / 'eval'
82
+ sample_image_paths = sorted(
83
+ sample_image_dir.glob(f'step_{trainer.paint_step:05d}_*.jpg'))
84
+ sample_image_paths = [
85
+ path.as_posix() for path in sample_image_paths
86
+ ]
87
+ yield sample_image_paths, None, None, f'{step}/{total_steps}'
88
+
89
+ trainer.mesh_model.change_default_to_median()
90
+
91
+ save_dir = trainer.exp_path / 'mesh'
92
+ save_dir.mkdir(exist_ok=True, parents=True)
93
+ trainer.mesh_model.export_mesh(save_dir)
94
+ model_path = save_dir / 'mesh.obj'
95
+ mesh = trimesh.load(model_path)
96
+ mesh_path = save_dir / 'mesh.glb'
97
+ mesh.export(mesh_path, file_type='glb')
98
+
99
+ zip_path = self.zip_results(config.log.exp_dir)
100
+ yield sample_image_paths, mesh_path.as_posix(), zip_path, 'Done!'
101
 
102
 
103
  @spaces.GPU