kyleleey commited on
Commit
f09d510
1 Parent(s): d7c4a03

init version of app

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +79 -24
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
  python_version: 3.9.13
8
- sdk_version: 4.12.0
9
  app_file: app.py
10
  pinned: false
11
  license: cc-by-nc-sa-4.0
 
5
  colorTo: green
6
  sdk: gradio
7
  python_version: 3.9.13
8
+ sdk_version: 3.50.2
9
  app_file: app.py
10
  pinned: false
11
  license: cc-by-nc-sa-4.0
app.py CHANGED
@@ -98,7 +98,7 @@ def expand2square(pil_img, background_color):
98
  return result
99
 
100
 
101
- def preprocess(predictor, input_image, chk_group=None, segment=True):
102
  RES = 1024
103
  input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
104
  if chk_group is not None:
@@ -403,13 +403,76 @@ def create_bones_scene(bones, joint_color=[66, 91, 140], bone_color=[119, 144, 1
403
  return mesh
404
 
405
 
406
- def run_pipeline(model_items, cfgs, input_img, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  epoch = 999
408
  total_iter = 999999
409
  model = model_items[0]
410
  memory_bank = model_items[1]
411
  memory_bank_keys = model_items[2]
412
 
 
 
413
  input_image = torch.stack([torchvision.transforms.ToTensor()(input_img)], dim=0).to(device)
414
 
415
  with torch.no_grad():
@@ -455,7 +518,7 @@ def run_pipeline(model_items, cfgs, input_img, device):
455
  gray_light = FixedDirectionLight(direction=torch.FloatTensor([0, 0, 1]).to(device), amb=0.2, diff=0.7)
456
 
457
  image_pred, mask_pred, _, _, _, shading = model.render(
458
- shape, texture_pred, mvp, w2c, campos, 256, background=model.background_mode,
459
  im_features=im_features, light=gray_light, prior_shape=prior_shape, render_mode='diffuse',
460
  render_flow=False, dino_pred=None, im_features_map=im_features_map
461
  )
@@ -469,7 +532,7 @@ def run_pipeline(model_items, cfgs, input_img, device):
469
  nv_meshes = make_mesh(verts=bones_meshes.verts_padded(), faces=bones_meshes.faces_padded()[0:1],
470
  uvs=bones_meshes.textures.verts_uvs_padded(), uv_idx=bones_meshes.textures.faces_uvs_padded()[0:1],
471
  material=material_texture.Texture2D(bones_meshes.textures.maps_padded()))
472
- buffers = render_mesh(dr.RasterizeGLContext(), nv_meshes, mvp, w2c, campos, nv_meshes.material, lgt=gray_light, feat=im_features, dino_pred=None, resolution=256, bsdf="diffuse")
473
 
474
  shaded = buffers["shaded"].permute(0, 3, 1, 2)
475
  bone_image = shaded[:, :3, :, :]
@@ -481,20 +544,10 @@ def run_pipeline(model_items, cfgs, input_img, device):
481
  mesh_image = save_images(shading, mask_pred)
482
  mesh_bones_image = save_images(image_with_bones, mask_final)
483
 
484
- final_shape = shape.clone()
485
- prior_shape = prior_shape.clone()
486
-
487
- final_mesh_tri = trimesh.Trimesh(
488
- vertices=final_shape.v_pos[0].detach().cpu().numpy(),
489
- faces=final_shape.t_pos_idx[0].detach().cpu().numpy(),
490
- process=False,
491
- maintain_order=True)
492
- prior_mesh_tri = trimesh.Trimesh(
493
- vertices=prior_shape.v_pos[0].detach().cpu().numpy(),
494
- faces=prior_shape.t_pos_idx[0].detach().cpu().numpy(),
495
- process=False,
496
- maintain_order=True)
497
 
 
498
 
499
 
500
  def run_demo():
@@ -582,7 +635,6 @@ def run_demo():
582
  with gr.Column():
583
  input_processing = gr.CheckboxGroup(['Use SAM to center animal'],
584
  label='Input Image Preprocessing',
585
- value=['Use SAM to center animal'],
586
  info='untick this, if animal is already centered, e.g. in example images')
587
  # with gr.Column():
588
  # output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
@@ -599,23 +651,26 @@ def run_demo():
599
  # with gr.Column():
600
  # crop_size = gr.Number(192, label='Crop size')
601
  # crop_size = 192
602
- run_btn = gr.Button('Generate', variant='primary', interactive=True)
603
  with gr.Row():
604
  view_1 = gr.Image(interactive=False, height=256, show_label=False)
605
  view_2 = gr.Image(interactive=False, height=256, show_label=False)
606
  with gr.Row():
607
- shape_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Reconstructed Model")
608
- shape_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Bank Base Shape Model")
 
 
 
609
 
610
  run_btn.click(fn=partial(preprocess, predictor),
611
  inputs=[input_image, input_processing],
612
  outputs=[processed_image_highres, processed_image], queue=True
613
  ).success(fn=partial(run_pipeline, model_items, model_cfgs),
614
- inputs=[processed_image, device],
615
- outputs=[view_1, view_2, shape_1, shape_2]
616
  )
617
  demo.queue().launch(share=True, max_threads=80)
618
- # _, local_url, share_url = demo.launch(share=True, server_name="0.0.0.0", server_port=23425)
619
  # print('local_url: ', local_url)
620
 
621
 
 
98
  return result
99
 
100
 
101
+ def preprocess(predictor, input_image, chk_group=None, segment=False):
102
  RES = 1024
103
  input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
104
  if chk_group is not None:
 
403
  return mesh
404
 
405
 
406
+ def save_mesh(mesh, file_path):
407
+ obj_file = file_path
408
+ idx = 0
409
+ print("Writing mesh: ", obj_file)
410
+ with open(obj_file, "w") as f:
411
+ # f.write(f"mtllib {fname}.mtl\n")
412
+ f.write("g default\n")
413
+
414
+ v_pos = mesh.v_pos[idx].detach().cpu().numpy() if mesh.v_pos is not None else None
415
+ v_nrm = mesh.v_nrm[idx].detach().cpu().numpy() if mesh.v_nrm is not None else None
416
+ v_tex = mesh.v_tex[idx].detach().cpu().numpy() if mesh.v_tex is not None else None
417
+
418
+ t_pos_idx = mesh.t_pos_idx[0].detach().cpu().numpy() if mesh.t_pos_idx is not None else None
419
+ t_nrm_idx = mesh.t_nrm_idx[0].detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
420
+ t_tex_idx = mesh.t_tex_idx[0].detach().cpu().numpy() if mesh.t_tex_idx is not None else None
421
+
422
+ print(" writing %d vertices" % len(v_pos))
423
+ for v in v_pos:
424
+ f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
425
+
426
+ if v_nrm is not None:
427
+ print(" writing %d normals" % len(v_nrm))
428
+ assert(len(t_pos_idx) == len(t_nrm_idx))
429
+ for v in v_nrm:
430
+ f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
431
+
432
+ # faces
433
+ f.write("s 1 \n")
434
+ f.write("g pMesh1\n")
435
+ f.write("usemtl defaultMat\n")
436
+
437
+ # Write faces
438
+ print(" writing %d faces" % len(t_pos_idx))
439
+ for i in range(len(t_pos_idx)):
440
+ f.write("f ")
441
+ for j in range(3):
442
+ f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
443
+ f.write("\n")
444
+
445
+
446
+ def process_mesh(shape, name):
447
+ mesh = shape.clone()
448
+ output_glb = f'./{name}.glb'
449
+ output_obj = f'./{name}.obj'
450
+
451
+ # save the obj file for download
452
+ save_mesh(mesh, output_obj)
453
+
454
+ # save the glb for visualize
455
+ mesh_tri = trimesh.Trimesh(
456
+ vertices=mesh.v_pos[0].detach().cpu().numpy(),
457
+ faces=mesh.t_pos_idx[0][..., [2,1,0]].detach().cpu().numpy(),
458
+ process=False,
459
+ maintain_order=True
460
+ )
461
+ mesh_tri.visual.vertex_colors = (mesh.v_nrm[0][..., [2,1,0]].detach().cpu().numpy() + 1.0) * 0.5 * 255.0
462
+ mesh_tri.export(file_obj=output_glb)
463
+
464
+ return output_glb, output_obj
465
+
466
+
467
+ def run_pipeline(model_items, cfgs, input_img):
468
  epoch = 999
469
  total_iter = 999999
470
  model = model_items[0]
471
  memory_bank = model_items[1]
472
  memory_bank_keys = model_items[2]
473
 
474
+ device = f'cuda:{_GPU_ID}'
475
+
476
  input_image = torch.stack([torchvision.transforms.ToTensor()(input_img)], dim=0).to(device)
477
 
478
  with torch.no_grad():
 
518
  gray_light = FixedDirectionLight(direction=torch.FloatTensor([0, 0, 1]).to(device), amb=0.2, diff=0.7)
519
 
520
  image_pred, mask_pred, _, _, _, shading = model.render(
521
+ shape, texture_pred, mvp, w2c, campos, (256, 256), background=model.background_mode,
522
  im_features=im_features, light=gray_light, prior_shape=prior_shape, render_mode='diffuse',
523
  render_flow=False, dino_pred=None, im_features_map=im_features_map
524
  )
 
532
  nv_meshes = make_mesh(verts=bones_meshes.verts_padded(), faces=bones_meshes.faces_padded()[0:1],
533
  uvs=bones_meshes.textures.verts_uvs_padded(), uv_idx=bones_meshes.textures.faces_uvs_padded()[0:1],
534
  material=material_texture.Texture2D(bones_meshes.textures.maps_padded()))
535
+ buffers = render_mesh(dr.RasterizeGLContext(), nv_meshes, mvp, w2c, campos, nv_meshes.material, lgt=gray_light, feat=im_features, dino_pred=None, resolution=(256,256), bsdf="diffuse")
536
 
537
  shaded = buffers["shaded"].permute(0, 3, 1, 2)
538
  bone_image = shaded[:, :3, :, :]
 
544
  mesh_image = save_images(shading, mask_pred)
545
  mesh_bones_image = save_images(image_with_bones, mask_final)
546
 
547
+ shape_glb, shape_obj = process_mesh(shape, 'reconstruced_shape')
548
+ base_shape_glb, base_shape_obj = process_mesh(prior_shape, 'reconstructed_base_shape')
 
 
 
 
 
 
 
 
 
 
 
549
 
550
+ return mesh_image, mesh_bones_image, shape_glb, shape_obj, base_shape_glb, base_shape_obj
551
 
552
 
553
  def run_demo():
 
635
  with gr.Column():
636
  input_processing = gr.CheckboxGroup(['Use SAM to center animal'],
637
  label='Input Image Preprocessing',
 
638
  info='untick this, if animal is already centered, e.g. in example images')
639
  # with gr.Column():
640
  # output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
 
651
  # with gr.Column():
652
  # crop_size = gr.Number(192, label='Crop size')
653
  # crop_size = 192
654
+ run_btn = gr.Button('Reconstruct', variant='primary', interactive=True)
655
  with gr.Row():
656
  view_1 = gr.Image(interactive=False, height=256, show_label=False)
657
  view_2 = gr.Image(interactive=False, height=256, show_label=False)
658
  with gr.Row():
659
+ shape_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], height=512, label="Reconstructed Model")
660
+ shape_1_download = gr.File(label="Download Full Reconstructed Model")
661
+ with gr.Row():
662
+ shape_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], height=512, label="Bank Base Shape Model")
663
+ shape_2_download = gr.File(label="Download Full Bank Base Shape Model")
664
 
665
  run_btn.click(fn=partial(preprocess, predictor),
666
  inputs=[input_image, input_processing],
667
  outputs=[processed_image_highres, processed_image], queue=True
668
  ).success(fn=partial(run_pipeline, model_items, model_cfgs),
669
+ inputs=[processed_image],
670
+ outputs=[view_1, view_2, shape_1, shape_1_download, shape_2, shape_2_download]
671
  )
672
  demo.queue().launch(share=True, max_threads=80)
673
+ # _, local_url, share_url = demo.queue().launch(share=True, server_name="0.0.0.0", server_port=23425)
674
  # print('local_url: ', local_url)
675
 
676