abreza committed on
Commit
f21e8dd
1 Parent(s): 70906c2

Update scripts/utils.py

Browse files
Files changed (1) hide show
  1. scripts/utils.py +156 -232
scripts/utils.py CHANGED
@@ -1,19 +1,19 @@
1
  import torch
2
  import numpy as np
3
  from PIL import Image
4
- import pymeshlab
5
  import pymeshlab as ml
6
- from pymeshlab import PercentageValue
7
  from pytorch3d.renderer import TexturesVertex
8
  from pytorch3d.structures import Meshes
9
  from rembg import new_session, remove
10
- import torch
11
- import torch.nn.functional as F
12
- from typing import List, Tuple
13
- from PIL import Image
14
  import trimesh
 
 
 
 
 
15
 
16
- providers = [
 
17
  ('CUDAExecutionProvider', {
18
  'device_id': 0,
19
  'arena_extend_strategy': 'kSameAsRequested',
@@ -22,298 +22,222 @@ providers = [
22
  })
23
  ]
24
 
25
- session = new_session(providers=providers)
26
-
27
- NEG_PROMPT="sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy,bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry,(worst quality:1.4),(low quality:1.4)"
28
 
 
29
  def load_mesh_with_trimesh(file_name, file_type=None):
30
- import trimesh
31
- mesh: trimesh.Trimesh = trimesh.load(file_name, file_type=file_type)
32
  if isinstance(mesh, trimesh.Scene):
33
- assert len(mesh.geometry) > 0
34
- # save to obj first and load again to avoid offset issue
35
- from io import BytesIO
36
- with BytesIO() as f:
37
- mesh.export(f, file_type="obj")
38
- f.seek(0)
39
- mesh = trimesh.load(f, file_type="obj")
40
- if isinstance(mesh, trimesh.Scene):
41
- # we lose texture information here
42
- mesh = trimesh.util.concatenate(
43
- tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
44
- for g in mesh.geometry.values()))
45
- assert isinstance(mesh, trimesh.Trimesh)
46
-
47
  vertices = torch.from_numpy(mesh.vertices).T
48
  faces = torch.from_numpy(mesh.faces).T
49
- colors = None
50
- if mesh.visual is not None:
51
- if hasattr(mesh.visual, 'vertex_colors'):
52
- colors = torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255.
53
- if colors is None:
54
- # print("Warning: no vertex color found in mesh! Filling it with gray.")
55
- colors = torch.ones_like(vertices) * 0.5
56
  return vertices, faces, colors
57
 
58
- def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  verts = torch.from_numpy(mesh.vertex_matrix()).float()
60
  faces = torch.from_numpy(mesh.face_matrix()).long()
61
  colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
62
  textures = TexturesVertex(verts_features=[colors])
63
  return Meshes(verts=[verts], faces=[faces], textures=textures)
64
 
65
-
66
- def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh:
67
  colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64)
68
- m1 = pymeshlab.Mesh(
69
  vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64),
70
  face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32),
71
  v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64),
72
  v_color_matrix=colors_in)
73
- return m1
74
-
75
-
76
- def to_pyml_mesh(vertices,faces):
77
- m1 = pymeshlab.Mesh(
78
- vertex_matrix=vertices.cpu().float().numpy().astype(np.float64),
79
- face_matrix=faces.cpu().long().numpy().astype(np.int32),
80
- )
81
- return m1
82
-
83
-
84
- def to_py3d_mesh(vertices, faces, normals=None):
85
- from pytorch3d.structures import Meshes
86
- from pytorch3d.renderer.mesh.textures import TexturesVertex
87
- mesh = Meshes(verts=[vertices], faces=[faces], textures=None)
88
- if normals is None:
89
- normals = mesh.verts_normals_packed()
90
- # set normals as vertext colors
91
- mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
92
- return mesh
93
-
94
-
95
- def from_py3d_mesh(mesh):
96
- return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed()
97
 
 
98
  def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float):
99
- """
100
- rotate along y-axis
101
- normal_map: np.array, shape=(H, W, 3) in [-1, 1]
102
- angle: float, in degree
103
- """
104
- angle = angle / 180 * np.pi
105
- R = np.array([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]])
106
  return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape)
107
 
108
- # from view coord to front view world coord
109
- def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> np.ndarray: # [0, 255]
110
  n_views = len(normal_pils)
111
  ret = []
112
  for idx, rgba_normal in enumerate(normal_pils):
113
- # rotate normal
114
- normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1]
115
- alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
116
- normal_np = normal_np * 2 - 1
117
- normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views))
118
- normal_np = (normal_np + 1) / 2
119
- normal_np = normal_np * alpha_np[..., None] # make bg black
120
- rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255] , axis=-1)
121
- if return_types == 'np':
122
- ret.append(rgba_normal_np)
123
- elif return_types == 'pil':
124
- ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
125
- else:
126
- raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
127
  return ret
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- def rotate_normalmap_by_angle_torch(normal_map, angle):
131
- """
132
- rotate along y-axis
133
- normal_map: torch.Tensor, shape=(H, W, 3) in [-1, 1], device='cuda'
134
- angle: float, in degree
135
- """
136
- angle = torch.tensor(angle / 180 * np.pi).to(normal_map)
137
- R = torch.tensor([[torch.cos(angle), 0, torch.sin(angle)],
138
- [0, 1, 0],
139
- [-torch.sin(angle), 0, torch.cos(angle)]]).to(normal_map)
140
- return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape)
141
-
142
- def do_rotate(rgba_normal, angle):
143
- rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255
144
- rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle)
145
- rotated_normal_tensor = (rotated_normal_tensor + 1) / 2
146
- rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black
147
- rgba_normal_np = torch.cat([rotated_normal_tensor * 255, rgba_normal[:, :, [3]] * 255], dim=-1).cpu().numpy()
148
- return rgba_normal_np
149
-
150
- def rotate_normals_torch(normal_pils, return_types='np', rotate_direction=1):
151
- n_views = len(normal_pils)
152
- ret = []
153
- for idx, rgba_normal in enumerate(normal_pils):
154
- # rotate normal
155
- angle = rotate_direction * idx * (360 / n_views)
156
- rgba_normal_np = do_rotate(np.array(rgba_normal), angle)
157
- if return_types == 'np':
158
- ret.append(rgba_normal_np)
159
- elif return_types == 'pil':
160
- ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
161
- else:
162
- raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
163
- return ret
164
-
165
  def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)):
166
- ret = []
167
  new_bkgd = np.array(new_bkgd).reshape(1, 1, 3)
168
- for rgba_img in img_pils:
169
- img_np = np.array(rgba_img)[:, :, :3] / 255
170
- alpha_np = np.array(rgba_img)[:, :, 3] / 255
171
- ori_bkgd = img_np[:1, :1]
172
- # color = ori_color * alpha + bkgd * (1-alpha)
173
- # ori_color = (color - bkgd * (1-alpha)) / alpha
174
- alpha_np_clamp = np.clip(alpha_np, 1e-6, 1) # avoid divide by zero
175
- ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None]
176
- img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd)
177
- rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1)
178
- ret.append(Image.fromarray(rgba_img_np.astype(np.uint8)))
179
- return ret
180
-
181
- def change_bkgd_to_normal(normal_pils) -> List[Image.Image]:
182
- n_views = len(normal_pils)
183
- ret = []
184
- for idx, rgba_normal in enumerate(normal_pils):
185
- # calcuate background normal
186
- target_bkgd = rotate_normalmap_by_angle(np.array([[[0., 0., 1.]]]), idx * (360 / n_views))
187
- normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1]
188
- alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
189
- normal_np = normal_np * 2 - 1
190
- old_bkgd = normal_np[:1,:1]
191
- normal_np[alpha_np > 0.05] = (normal_np[alpha_np > 0.05] - old_bkgd * (1 - alpha_np[alpha_np > 0.05][..., None])) / alpha_np[alpha_np > 0.05][..., None]
192
- normal_np = normal_np * alpha_np[..., None] + target_bkgd * (1 - alpha_np[..., None])
193
- normal_np = (normal_np + 1) / 2
194
- rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[..., None] * 255] , axis=-1)
195
- ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
196
- return ret
197
-
198
-
199
- def fix_vert_color_glb(mesh_path):
200
- from pygltflib import GLTF2, Material, PbrMetallicRoughness
201
- obj1 = GLTF2().load(mesh_path)
202
- obj1.meshes[0].primitives[0].material = 0
203
- obj1.materials.append(Material(
204
- pbrMetallicRoughness = PbrMetallicRoughness(
205
- baseColorFactor = [1.0, 1.0, 1.0, 1.0],
206
- metallicFactor = 0.,
207
- roughnessFactor = 1.0,
208
- ),
209
- emissiveFactor = [0.0, 0.0, 0.0],
210
- doubleSided = True,
211
- ))
212
- obj1.save(mesh_path)
213
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- def srgb_to_linear(c_srgb):
216
- c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
217
- return c_linear.clip(0, 1.)
 
 
 
218
 
 
 
 
 
 
 
 
 
219
 
 
220
  def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
221
- # convert from pytorch3d meshes to trimesh mesh
222
  vertices = meshes.verts_packed().cpu().float().numpy()
223
  triangles = meshes.faces_packed().cpu().long().numpy()
224
  np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
 
225
  if save_glb_path.endswith(".glb"):
226
- # rotate 180 along +Y
227
  vertices[:, [0, 2]] = -vertices[:, [0, 2]]
228
 
229
  if apply_sRGB_to_LinearRGB:
230
  np_color = srgb_to_linear(np_color)
231
- assert vertices.shape[0] == np_color.shape[0]
232
- assert np_color.shape[1] == 3
233
- assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
234
  mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
235
  mesh.remove_unreferenced_vertices()
236
- # save mesh
237
  mesh.export(save_glb_path)
 
238
  if save_glb_path.endswith(".glb"):
239
  fix_vert_color_glb(save_glb_path)
240
- print(f"saving to {save_glb_path}")
241
-
242
 
243
- def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]:
244
  import time
245
  if '.' in save_mesh_prefix:
246
  save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
247
  if with_timestamp:
248
  save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}"
249
  ret_mesh = save_mesh_prefix + ".glb"
250
- # optimizied version
251
  save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
252
  return ret_mesh, None
253
 
 
 
 
254
 
255
- def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
256
- ms = ml.MeshSet()
257
- ms.add_mesh(pyml_mesh, "cube_mesh")
258
-
259
- if apply_smooth:
260
- ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False)
261
- if apply_sub_divide: # 5s, slow
262
- ms.apply_filter("meshing_repair_non_manifold_vertices")
263
- ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces')
264
- ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold))
265
- return meshlab_mesh_to_py3dmesh(ms.current_mesh())
266
-
267
-
268
- def expand2square(pil_img, background_color):
269
- width, height = pil_img.size
270
- if width == height:
271
- return pil_img
272
- elif width > height:
273
- result = Image.new(pil_img.mode, (width, width), background_color)
274
- result.paste(pil_img, (0, (width - height) // 2))
275
- return result
276
- else:
277
- result = Image.new(pil_img.mode, (height, height), background_color)
278
- result.paste(pil_img, ((height - width) // 2, 0))
279
- return result
280
-
281
-
282
- def simple_preprocess(input_image, rembg_session=session, background_color=255):
283
- RES = 2048
284
- input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
285
- if input_image.mode != 'RGBA':
286
- image_rem = input_image.convert('RGBA')
287
- input_image = remove(image_rem, alpha_matting=False, session=rembg_session)
288
-
289
- arr = np.asarray(input_image)
290
- alpha = np.asarray(input_image)[:, :, -1]
291
- x_nonzero = np.nonzero((alpha > 60).sum(axis=1))
292
- y_nonzero = np.nonzero((alpha > 60).sum(axis=0))
293
- x_min = int(x_nonzero[0].min())
294
- y_min = int(y_nonzero[0].min())
295
- x_max = int(x_nonzero[0].max())
296
- y_max = int(y_nonzero[0].max())
297
- arr = arr[x_min: x_max, y_min: y_max]
298
- input_image = Image.fromarray(arr)
299
- input_image = expand2square(input_image, (background_color, background_color, background_color, 0))
300
- return input_image
301
 
302
  def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
303
- # Convert the background color to a PyTorch tensor
304
  new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
305
-
306
- # Convert all images to PyTorch tensors and process them
307
  imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255
308
- img_nps = imgs[..., :3]
309
- alpha_nps = imgs[..., 3]
310
  ori_bkgds = img_nps[:, :1, :1]
311
 
312
- # Avoid divide by zero and calculate the original image
313
  alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1)
314
  ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1)
315
  ori_img_nps = torch.clamp(ori_img_nps, 0, 1)
316
  img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd)
317
 
318
- rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1)
319
- return rgba_img_np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
  from PIL import Image
 
4
  import pymeshlab as ml
 
5
  from pytorch3d.renderer import TexturesVertex
6
  from pytorch3d.structures import Meshes
7
  from rembg import new_session, remove
 
 
 
 
8
  import trimesh
9
+ from typing import List, Tuple
10
+ import torch.nn.functional as F
11
+
12
+ # Constants
13
+ NEG_PROMPT = "sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy, bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, (worst quality:1.4), (low quality:1.4)"
14
 
15
+ # CUDA Configuration
16
+ CUDA_PROVIDERS = [
17
  ('CUDAExecutionProvider', {
18
  'device_id': 0,
19
  'arena_extend_strategy': 'kSameAsRequested',
 
22
  })
23
  ]
24
 
25
+ # Initialize rembg session
26
+ rembg_session = new_session(providers=CUDA_PROVIDERS)
 
27
 
28
+ # Mesh Loading and Conversion Functions
29
  def load_mesh_with_trimesh(file_name, file_type=None):
30
+ mesh = trimesh.load(file_name, file_type=file_type)
 
31
  if isinstance(mesh, trimesh.Scene):
32
+ mesh = _process_trimesh_scene(mesh)
33
+
 
 
 
 
 
 
 
 
 
 
 
 
34
  vertices = torch.from_numpy(mesh.vertices).T
35
  faces = torch.from_numpy(mesh.faces).T
36
+ colors = _get_mesh_colors(mesh)
37
+
 
 
 
 
 
38
  return vertices, faces, colors
39
 
40
+ def _process_trimesh_scene(mesh):
41
+ from io import BytesIO
42
+ with BytesIO() as f:
43
+ mesh.export(f, file_type="obj")
44
+ f.seek(0)
45
+ mesh = trimesh.load(f, file_type="obj")
46
+ if isinstance(mesh, trimesh.Scene):
47
+ mesh = trimesh.util.concatenate(
48
+ tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
49
+ for g in mesh.geometry.values()))
50
+ return mesh
51
+
52
+ def _get_mesh_colors(mesh):
53
+ if mesh.visual is not None and hasattr(mesh.visual, 'vertex_colors'):
54
+ return torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255.
55
+ return torch.ones_like(mesh.vertices.T) * 0.5
56
+
57
+ # Mesh Conversion Functions
58
+ def meshlab_mesh_to_py3dmesh(mesh: ml.Mesh) -> Meshes:
59
  verts = torch.from_numpy(mesh.vertex_matrix()).float()
60
  faces = torch.from_numpy(mesh.face_matrix()).long()
61
  colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
62
  textures = TexturesVertex(verts_features=[colors])
63
  return Meshes(verts=[verts], faces=[faces], textures=textures)
64
 
65
+ def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> ml.Mesh:
 
66
  colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64)
67
+ return ml.Mesh(
68
  vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64),
69
  face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32),
70
  v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64),
71
  v_color_matrix=colors_in)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ # Normal Map Rotation Functions
74
  def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float):
75
+ angle_rad = np.radians(angle)
76
+ R = np.array([
77
+ [np.cos(angle_rad), 0, np.sin(angle_rad)],
78
+ [0, 1, 0],
79
+ [-np.sin(angle_rad), 0, np.cos(angle_rad)]
80
+ ])
 
81
  return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape)
82
 
83
+ def rotate_normals(normal_pils, return_types='np', rotate_direction=1):
 
84
  n_views = len(normal_pils)
85
  ret = []
86
  for idx, rgba_normal in enumerate(normal_pils):
87
+ normal_np = _process_normal_map(rgba_normal, idx, n_views, rotate_direction)
88
+ ret.append(_format_output(normal_np, return_types))
 
 
 
 
 
 
 
 
 
 
 
 
89
  return ret
90
 
91
+ def _process_normal_map(rgba_normal, idx, n_views, rotate_direction):
92
+ normal_np = np.array(rgba_normal)[:, :, :3] / 255 * 2 - 1
93
+ alpha_np = np.array(rgba_normal)[:, :, 3] / 255
94
+ normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views))
95
+ normal_np = (normal_np + 1) / 2 * alpha_np[..., None]
96
+ return np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255], axis=-1)
97
+
98
+ def _format_output(normal_np, return_types):
99
+ if return_types == 'np':
100
+ return normal_np
101
+ elif return_types == 'pil':
102
+ return Image.fromarray(normal_np.astype(np.uint8))
103
+ else:
104
+ raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
105
 
106
+ # Background Change Functions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)):
 
108
  new_bkgd = np.array(new_bkgd).reshape(1, 1, 3)
109
+ return [_process_image(rgba_img, new_bkgd) for rgba_img in img_pils]
110
+
111
+ def _process_image(rgba_img, new_bkgd):
112
+ img_np = np.array(rgba_img)[:, :, :3] / 255
113
+ alpha_np = np.array(rgba_img)[:, :, 3] / 255
114
+ ori_bkgd = img_np[:1, :1]
115
+ alpha_np_clamp = np.clip(alpha_np, 1e-6, 1)
116
+ ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None]
117
+ img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd)
118
+ rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1)
119
+ return Image.fromarray(rgba_img_np.astype(np.uint8))
120
+
121
+ # Mesh Cleaning Function
122
+ def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
123
+ ms = ml.MeshSet()
124
+ ms.add_mesh(pyml_mesh, "cube_mesh")
125
+
126
+ if apply_smooth:
127
+ ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False)
128
+ if apply_sub_divide:
129
+ ms.apply_filter("meshing_repair_non_manifold_vertices")
130
+ ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces')
131
+ ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=ml.PercentageValue(sub_divide_threshold))
132
+ return meshlab_mesh_to_py3dmesh(ms.current_mesh())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ # Image Processing Functions
135
+ def expand2square(pil_img, background_color):
136
+ width, height = pil_img.size
137
+ if width == height:
138
+ return pil_img
139
+ new_size = max(width, height)
140
+ result = Image.new(pil_img.mode, (new_size, new_size), background_color)
141
+ offset = ((new_size - width) // 2, (new_size - height) // 2)
142
+ result.paste(pil_img, offset)
143
+ return result
144
 
145
+ def simple_preprocess(input_image, rembg_session=rembg_session, background_color=255):
146
+ RES = 2048
147
+ input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
148
+ if input_image.mode != 'RGBA':
149
+ image_rem = input_image.convert('RGBA')
150
+ input_image = remove(image_rem, alpha_matting=False, session=rembg_session)
151
 
152
+ arr = np.asarray(input_image)
153
+ alpha = arr[:, :, -1]
154
+ x_nonzero, y_nonzero = np.nonzero(alpha > 60)
155
+ x_min, x_max = x_nonzero.min(), x_nonzero.max()
156
+ y_min, y_max = y_nonzero.min(), y_nonzero.max()
157
+ arr = arr[x_min:x_max+1, y_min:y_max+1]
158
+ input_image = Image.fromarray(arr)
159
+ return expand2square(input_image, (background_color, background_color, background_color, 0))
160
 
161
+ # Mesh Saving Functions
162
  def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
 
163
  vertices = meshes.verts_packed().cpu().float().numpy()
164
  triangles = meshes.faces_packed().cpu().long().numpy()
165
  np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
166
+
167
  if save_glb_path.endswith(".glb"):
 
168
  vertices[:, [0, 2]] = -vertices[:, [0, 2]]
169
 
170
  if apply_sRGB_to_LinearRGB:
171
  np_color = srgb_to_linear(np_color)
172
+
 
 
173
  mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
174
  mesh.remove_unreferenced_vertices()
 
175
  mesh.export(save_glb_path)
176
+
177
  if save_glb_path.endswith(".glb"):
178
  fix_vert_color_glb(save_glb_path)
179
+ print(f"Saved to {save_glb_path}")
 
180
 
181
+ def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, **kwargs) -> Tuple[str, str]:
182
  import time
183
  if '.' in save_mesh_prefix:
184
  save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
185
  if with_timestamp:
186
  save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}"
187
  ret_mesh = save_mesh_prefix + ".glb"
 
188
  save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
189
  return ret_mesh, None
190
 
191
+ # Utility Functions
192
+ def srgb_to_linear(c_srgb):
193
+ return np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4).clip(0, 1.)
194
 
195
+ def fix_vert_color_glb(mesh_path):
196
+ from pygltflib import GLTF2, Material, PbrMetallicRoughness
197
+ obj1 = GLTF2().load(mesh_path)
198
+ obj1.meshes[0].primitives[0].material = 0
199
+ obj1.materials.append(Material(
200
+ pbrMetallicRoughness = PbrMetallicRoughness(
201
+ baseColorFactor = [1.0, 1.0, 1.0, 1.0],
202
+ metallicFactor = 0.,
203
+ roughnessFactor = 1.0,
204
+ ),
205
+ emissiveFactor = [0.0, 0.0, 0.0],
206
+ doubleSided = True,
207
+ ))
208
+ obj1.save(mesh_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
 
211
  new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
 
 
212
  imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255
213
+ img_nps, alpha_nps = imgs[..., :3], imgs[..., 3]
 
214
  ori_bkgds = img_nps[:, :1, :1]
215
 
 
216
  alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1)
217
  ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1)
218
  ori_img_nps = torch.clamp(ori_img_nps, 0, 1)
219
  img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd)
220
 
221
+ return torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1)
222
+
223
+ def save_obj_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, **kwargs) -> Tuple[str, str]:
224
+ if '.' in save_mesh_prefix:
225
+ save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
226
+ if with_timestamp:
227
+ save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}"
228
+ ret_mesh = save_mesh_prefix + ".obj"
229
+
230
+ vertices = meshes.verts_packed().cpu().float().numpy()
231
+ triangles = meshes.faces_packed().cpu().long().numpy()
232
+ np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
233
+
234
+ # Apply sRGB to LinearRGB conversion
235
+ np_color = srgb_to_linear(np_color)
236
+
237
+ mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
238
+ mesh.remove_unreferenced_vertices()
239
+ mesh.export(ret_mesh)
240
+
241
+ print(f"Saved to {ret_mesh}")
242
+
243
+ return ret_mesh, None