ZhengPeng7 commited on
Commit
621c740
1 Parent(s): a0e537e

Add fast-foreground-estimation in masking image.

Browse files
Files changed (1) hide show
  1. app.py +39 -8
app.py CHANGED
@@ -23,6 +23,40 @@ torch.jit.script = lambda f: f
23
 
24
  device = "cuda" if torch.cuda.is_available() else "CPU"
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def array_to_pil_image(image: np.ndarray, size: Tuple[int, int] = (1024, 1024)) -> Image.Image:
28
  image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
@@ -114,19 +148,16 @@ def predict(images, resolution, weights_file):
114
  if device == 'cuda':
115
  scaled_pred_tensor = scaled_pred_tensor.cpu()
116
 
117
- # Resize the prediction to match the original image shape
118
- pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
119
-
120
- # Apply the prediction mask to the original image
121
- image_pil = image_pil.resize(pred.shape[::-1])
122
- pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
123
- image_masked = (pred * np.array(image_pil)).astype(np.uint8)
124
 
125
  torch.cuda.empty_cache()
126
 
127
  if tab_is_batch:
128
  save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
129
- cv2.imwrite(save_file_path, cv2.cvtColor(image_masked, cv2.COLOR_RGB2BGR))
130
  save_paths.append(save_file_path)
131
 
132
  if tab_is_batch:
 
23
 
24
  device = "cuda" if torch.cuda.is_available() else "CPU"
25
 
26
+ ### image_proc.py
27
+ def refine_foreground(image, mask, r=90):
28
+ if mask.size != image.size:
29
+ mask = mask.resize(image.size)
30
+ image = np.array(image) / 255.0
31
+ mask = np.array(mask) / 255.0
32
+ estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
33
+ image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
34
+ return image_masked
35
+
36
+
37
+ def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
38
+ # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
39
+ alpha = alpha[:, :, None]
40
+ F, blur_B = FB_blur_fusion_foreground_estimator(
41
+ image, image, image, alpha, r)
42
+ return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
43
+
44
+
45
+ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
46
+ if isinstance(image, Image.Image):
47
+ image = np.array(image) / 255.0
48
+ blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
49
+
50
+ blurred_FA = cv2.blur(F * alpha, (r, r))
51
+ blurred_F = blurred_FA / (blurred_alpha + 1e-5)
52
+
53
+ blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
54
+ blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
55
+ F = blurred_F + alpha * \
56
+ (image - alpha * blurred_F - (1 - alpha) * blurred_B)
57
+ F = np.clip(F, 0, 1)
58
+ return F, blurred_B
59
+
60
 
61
  def array_to_pil_image(image: np.ndarray, size: Tuple[int, int] = (1024, 1024)) -> Image.Image:
62
  image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
 
148
  if device == 'cuda':
149
  scaled_pred_tensor = scaled_pred_tensor.cpu()
150
 
151
+ # Show Results
152
+ pred_pil = transforms.ToPILImage()(pred)
153
+ image_masked = refine_foreground(image, pred_pil)
154
+ image_masked.putalpha(pred_pil.resize(image.size))
 
 
 
155
 
156
  torch.cuda.empty_cache()
157
 
158
  if tab_is_batch:
159
  save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
160
+ image_masked.save(save_file_path)
161
  save_paths.append(save_file_path)
162
 
163
  if tab_is_batch: