kxhit commited on
Commit
48f1c75
·
1 Parent(s): 8db7bce
Files changed (3) hide show
  1. app.py +19 -1
  2. dust3r/utils/image.py +4 -6
  3. requirements.txt +1 -0
app.py CHANGED
@@ -73,7 +73,24 @@ from dataset import get_pose
73
  from CN_encoder import CN_encoder
74
  from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
75
  from segment_anything import sam_model_registry, SamPredictor
 
76
  import rembg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  pretrained_model_name_or_path = "kxic/EscherNet_demo"
79
  resolution = 256
@@ -128,7 +145,8 @@ def sam_init():
128
  predictor = SamPredictor(sam)
129
  return predictor
130
 
131
- rembg_session = rembg.new_session()
 
132
  predictor = sam_init()
133
 
134
 
 
73
  from CN_encoder import CN_encoder
74
  from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
75
  from segment_anything import sam_model_registry, SamPredictor
76
+
77
  import rembg
78
+ from carvekit.api.high import HiInterface
79
+
80
+ def create_carvekit_interface():
81
+ # Check doc strings for more information
82
+ interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
83
+ batch_size_seg=6,
84
+ batch_size_matting=1,
85
+ device=device,
86
+ seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
87
+ matting_mask_size=2048,
88
+ trimap_prob_threshold=231,
89
+ trimap_dilation=30,
90
+ trimap_erosion_iters=5,
91
+ fp16=True)
92
+
93
+ return interface
94
 
95
  pretrained_model_name_or_path = "kxic/EscherNet_demo"
96
  resolution = 256
 
145
  predictor = SamPredictor(sam)
146
  return predictor
147
 
148
+ # rembg_session = rembg.new_session()
149
+ rembg_session = create_carvekit_interface()
150
  predictor = sam_init()
151
 
152
 
dust3r/utils/image.py CHANGED
@@ -118,12 +118,10 @@ def load_images(folder_or_list, size, square_ok=False, verbose=True, do_remove_b
118
  img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
119
  # remove background if needed
120
  if do_remove_background:
121
- # if rembg_session is None:
122
- # rembg_session = rembg.new_session()
123
- # image = rembg.remove(img, session=rembg_session)
124
- # foreground = np.array(image)[..., -1] > 127
125
-
126
- image_nobg = remove(img, alpha_matting=True, session=rembg_session)
127
  arr = np.asarray(image_nobg)[:, :, -1]
128
  x_nonzero = np.nonzero(arr.sum(axis=0))
129
  y_nonzero = np.nonzero(arr.sum(axis=1))
 
118
  img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
119
  # remove background if needed
120
  if do_remove_background:
121
+ # use rembg
122
+ # image_nobg = remove(img, alpha_matting=True, session=rembg_session)
123
+ # use carvekit
124
+ image_nobg = rembg_session([img])[0]
 
 
125
  arr = np.asarray(image_nobg)[:, :, -1]
126
  x_nonzero = np.nonzero(arr.sum(axis=0))
127
  y_nonzero = np.nonzero(arr.sum(axis=1))
requirements.txt CHANGED
@@ -7,6 +7,7 @@ transformers
7
  gradio
8
  spaces
9
  rembg==2.0.56
 
10
  open3d
11
  trimesh
12
  einops
 
7
  gradio
8
  spaces
9
  rembg==2.0.56
10
+ carvekit-colab==4.1.0
11
  open3d
12
  trimesh
13
  einops