kyleleey commited on
Commit
166562d
1 Parent(s): 66abaaa

fix sam mask

Browse files
Files changed (1) hide show
  1. app.py +14 -24
app.py CHANGED
@@ -80,8 +80,18 @@ def sam_segment(predictor, input_image, *bbox_coords):
80
  out_image_bbox = out_image.copy()
81
  out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
82
  torch.cuda.empty_cache()
83
- return Image.fromarray(out_image_bbox, mode='RGB')
84
- # return Image.fromarray(out_image_bbox, mode='RGBA')
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
  def expand2square(pil_img, background_color):
@@ -113,28 +123,8 @@ def preprocess(predictor, input_image, chk_group=None, segment=False):
113
  x_max = int(x_nonzero[0].max())
114
  y_max = int(y_nonzero[0].max())
115
  input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
116
- # Rescale and recenter
117
- # if rescale:
118
- # image_arr = np.array(input_image)
119
- # in_w, in_h = image_arr.shape[:2]
120
- # out_res = min(RES, max(in_w, in_h))
121
- # ret, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY)
122
- # x, y, w, h = cv2.boundingRect(mask)
123
- # max_size = max(w, h)
124
- # ratio = 0.75
125
- # side_len = int(max_size / ratio)
126
- # padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
127
- # center = side_len//2
128
- # padded_image[center-h//2:center-h//2+h, center-w//2:center-w//2+w] = image_arr[y:y+h, x:x+w]
129
- # rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.LANCZOS)
130
-
131
- # rgba_arr = np.array(rgba) / 255.0
132
- # rgb = rgba_arr[...,:3] * rgba_arr[...,-1:] + (1 - rgba_arr[...,-1:])
133
- # input_image = Image.fromarray((rgb * 255).astype(np.uint8))
134
- # else:
135
- # input_image = expand2square(input_image, (127, 127, 127, 0))
136
 
137
- input_image = expand2square(input_image, (0, 0, 0))
138
  return input_image, input_image.resize((256, 256), Image.Resampling.LANCZOS)
139
 
140
 
@@ -545,7 +535,7 @@ def run_pipeline(model_items, cfgs, input_img):
545
  mesh_image = save_images(shading, mask_pred)
546
  mesh_bones_image = save_images(image_with_bones, mask_final)
547
 
548
- shape_glb, shape_obj = process_mesh(shape, 'reconstruced_shape')
549
  base_shape_glb, base_shape_obj = process_mesh(prior_shape, 'reconstructed_base_shape')
550
 
551
  return mesh_image, mesh_bones_image, shape_glb, shape_obj, base_shape_glb, base_shape_obj
 
80
  out_image_bbox = out_image.copy()
81
  out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
82
  torch.cuda.empty_cache()
83
+ # return Image.fromarray(out_image_bbox, mode='RGB')
84
+
85
+ x_nonzero = np.nonzero(masks_bbox[-1].astype(np.uint8).sum(axis=0))
86
+ y_nonzero = np.nonzero(masks_bbox[-1].astype(np.uint8).sum(axis=1))
87
+ x_min = int(x_nonzero[0].min())
88
+ y_min = int(y_nonzero[0].min())
89
+ x_max = int(x_nonzero[0].max())
90
+ y_max = int(y_nonzero[0].max())
91
+
92
+ out_image_bbox = out_image_bbox[y_min:y_max, x_min:x_max]
93
+
94
+ return Image.fromarray(out_image_bbox, mode='RGBA')
95
 
96
 
97
  def expand2square(pil_img, background_color):
 
123
  x_max = int(x_nonzero[0].max())
124
  y_max = int(y_nonzero[0].max())
125
  input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ input_image = expand2square(input_image, (0, 0, 0, 0))
128
  return input_image, input_image.resize((256, 256), Image.Resampling.LANCZOS)
129
 
130
 
 
535
  mesh_image = save_images(shading, mask_pred)
536
  mesh_bones_image = save_images(image_with_bones, mask_final)
537
 
538
+ shape_glb, shape_obj = process_mesh(shape, 'reconstructed_shape')
539
  base_shape_glb, base_shape_obj = process_mesh(prior_shape, 'reconstructed_base_shape')
540
 
541
  return mesh_image, mesh_bones_image, shape_glb, shape_obj, base_shape_glb, base_shape_obj