kyleleey commited on
Commit
d7c4a03
1 Parent(s): c2c8e22
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -17,7 +17,7 @@ import cv2
17
  import time
18
  import numpy as np
19
  import trimesh
20
- from segment_anything import sam_model_registry, SamPredictor
21
 
22
  import random
23
  from pytorch3d import transforms
@@ -56,9 +56,9 @@ if not hasattr(Image, 'Resampling'):
56
  def sam_init():
57
  sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
58
  model_type = "vit_h"
59
-
60
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}")
61
- predictor = SamPredictor(sam)
62
  return predictor
63
 
64
 
@@ -518,7 +518,7 @@ def run_demo():
518
  torch.cuda.set_device(_GPU_ID)
519
  args.rank = _GPU_ID
520
  args.world_size = 1
521
- args.gpu = os.environ['CUDA_VISIBLE_DEVICES']
522
  device = f'cuda:{_GPU_ID}'
523
 
524
  resolution = (256, 256)
@@ -607,11 +607,6 @@ def run_demo():
607
  shape_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Reconstructed Model")
608
  shape_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Bank Base Shape Model")
609
 
610
- with gr.Row():
611
- view_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)
612
- normal_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)
613
-
614
-
615
  run_btn.click(fn=partial(preprocess, predictor),
616
  inputs=[input_image, input_processing],
617
  outputs=[processed_image_highres, processed_image], queue=True
@@ -620,6 +615,8 @@ def run_demo():
620
  outputs=[view_1, view_2, shape_1, shape_2]
621
  )
622
  demo.queue().launch(share=True, max_threads=80)
 
 
623
 
624
 
625
  if __name__ == '__main__':
 
17
  import time
18
  import numpy as np
19
  import trimesh
20
+ from segment_anything import build_sam, SamPredictor
21
 
22
  import random
23
  from pytorch3d import transforms
 
56
  def sam_init():
57
  sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
58
  model_type = "vit_h"
59
+ # sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}")
60
+ # predictor = SamPredictor(sam)
61
+ predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to("cuda"))
62
  return predictor
63
 
64
 
 
518
  torch.cuda.set_device(_GPU_ID)
519
  args.rank = _GPU_ID
520
  args.world_size = 1
521
+ args.gpu = f'{_GPU_ID}'
522
  device = f'cuda:{_GPU_ID}'
523
 
524
  resolution = (256, 256)
 
607
  shape_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Reconstructed Model")
608
  shape_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Bank Base Shape Model")
609
 
 
 
 
 
 
610
  run_btn.click(fn=partial(preprocess, predictor),
611
  inputs=[input_image, input_processing],
612
  outputs=[processed_image_highres, processed_image], queue=True
 
615
  outputs=[view_1, view_2, shape_1, shape_2]
616
  )
617
  demo.queue().launch(share=True, max_threads=80)
618
+ # _, local_url, share_url = demo.launch(share=True, server_name="0.0.0.0", server_port=23425)
619
+ # print('local_url: ', local_url)
620
 
621
 
622
  if __name__ == '__main__':