ZhengPeng7 committed
Commit d8c0335 (parent 1c4ea53)

For users to load in one key.

Files changed (3):
  1. BiRefNet_github/models/birefnet.py +1 -1
  2. MyPipe.py +76 -0
  3. config.json +1 -0
BiRefNet_github/models/birefnet.py CHANGED
@@ -114,7 +114,7 @@ class BiRefNet(
     def forward(self, x):
         scaled_preds, class_preds = self.forward_ori(x)
         class_preds_lst = [class_preds]
-        return [scaled_preds, class_preds_lst] if self.training and 0 else scaled_preds
+        return [scaled_preds, class_preds_lst] if self.training else scaled_preds
 
 
 class Decoder(nn.Module):
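
The dropped `and 0` matters because it made the training branch unreachable: `self.training and 0` always evaluated to 0, so forward returned only scaled_preds in every mode. A minimal sketch of the behavior after this change, assuming the model loads through AutoModelForImageSegmentation with remote code as config.json suggests (the repo id below is an assumption, not part of the diff):

import torch
from transformers import AutoModelForImageSegmentation

# Assumed repo id, for illustration only; substitute the Hub repo that hosts this code.
model = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
x = torch.randn(1, 3, 1024, 1024)  # 1024x1024 matches the pipeline's default input size

model.train()
scaled_preds, class_preds_lst = model(x)   # training mode: both outputs, per the fixed line

model.eval()
with torch.no_grad():
    scaled_preds = model(x)                # inference: scaled predictions only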
MyPipe.py ADDED
@@ -0,0 +1,76 @@
+import torch, os
+import torch.nn.functional as F
+from torchvision.transforms.functional import normalize
+import numpy as np
+from transformers import Pipeline
+from transformers.image_utils import load_image
+from skimage import io
+from PIL import Image
+
+class RMBGPipe(Pipeline):
+    def __init__(self, **kwargs):
+        Pipeline.__init__(self, **kwargs)
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+        self.model.eval()
+
+    def _sanitize_parameters(self, **kwargs):
+        # parse parameters
+        preprocess_kwargs = {}
+        postprocess_kwargs = {}
+        if "model_input_size" in kwargs:
+            preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
+        if "return_mask" in kwargs:
+            postprocess_kwargs["return_mask"] = kwargs["return_mask"]
+        return preprocess_kwargs, {}, postprocess_kwargs
+
+    def preprocess(self, input_image, model_input_size: list = [1024, 1024]):
+        # preprocess the input
+        orig_im = load_image(input_image)
+        orig_im = np.array(orig_im)
+        orig_im_size = orig_im.shape[0:2]
+        preprocessed_image = self.preprocess_image(orig_im, model_input_size).to(self.device)
+        inputs = {
+            "preprocessed_image": preprocessed_image,
+            "orig_im_size": orig_im_size,
+            "input_image": input_image
+        }
+        return inputs
+
+    def _forward(self, inputs):
+        result = self.model(inputs.pop("preprocessed_image"))
+        inputs["result"] = result
+        return inputs
+
+    def postprocess(self, inputs, return_mask: bool = False):
+        result = inputs.pop("result")
+        orig_im_size = inputs.pop("orig_im_size")
+        input_image = inputs.pop("input_image")
+        result_image = self.postprocess_image(result[0][0], orig_im_size)
+        pil_im = Image.fromarray(result_image)
+        if return_mask:
+            return pil_im
+        no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
+        input_image = load_image(input_image)
+        no_bg_image.paste(input_image, mask=pil_im)
+        return no_bg_image
+
+    # utility functions
+    def preprocess_image(self, im: np.ndarray, model_input_size: list = [1024, 1024]) -> torch.Tensor:
+        # same as utilities.py with minor modification
+        if len(im.shape) < 3:
+            im = im[:, :, np.newaxis]
+        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
+        im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear')
+        image = torch.divide(im_tensor, 255.0)
+        image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+        return image
+
+    def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
+        result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+        ma = torch.max(result)
+        mi = torch.min(result)
+        result = (result - mi) / (ma - mi)
+        im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
+        im_array = np.squeeze(im_array)
+        return im_array
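
As a usage note, the class above can also be driven directly once a model object is in hand. A rough sketch, assuming the checkpoint loads through AutoModelForImageSegmentation with trust_remote_code=True as config.json indicates; the repo id and the image path are placeholders, not taken from the commit:

from transformers import AutoModelForImageSegmentation
from MyPipe import RMBGPipe

# Assumed repo id; replace with the Hub repo that contains this config and MyPipe.py.
model = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
pipe = RMBGPipe(model=model)

cutout = pipe("example.jpg")                   # input pasted onto a transparent background
mask = pipe("example.jpg", return_mask=True)   # predicted mask as a PIL image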
config.json CHANGED
@@ -9,6 +9,7 @@
   },
   "custom_pipelines": {
     "image-segmentation": {
+      "impl": "MyPipe.RMBGPipe",
       "pt": [
         "AutoModelForImageSegmentation"
       ],
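
The "impl" entry registered above is what enables the one-key loading the commit message describes: with remote code allowed, transformers fetches MyPipe.RMBGPipe from the repo and uses it for the image-segmentation task. A hedged sketch of that call, where the repo id and image URL are placeholders:

from transformers import pipeline

pipe = pipeline(
    "image-segmentation",
    model="ZhengPeng7/BiRefNet",   # assumed repo id; use the repo that hosts this config.json
    trust_remote_code=True,        # required so MyPipe.RMBGPipe is downloaded and used
)

no_bg = pipe("https://example.com/photo.jpg")                   # RGBA cutout
mask = pipe("https://example.com/photo.jpg", return_mask=True)  # mask only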