ZhengPeng7 commited on
Commit
ca16a0e
·
1 Parent(s): fb60ce9

Add inference endpoint feature.

Browse files
Files changed (3) hide show
  1. README.md +0 -4
  2. handler.py +12 -2
  3. requirements.txt +0 -2
README.md CHANGED
@@ -8,10 +8,6 @@ tags:
8
  - model_hub_mixin
9
  repo_url: https://github.com/ZhengPeng7/BiRefNet-legacy
10
  pipeline_tag: image-to-image
11
- widget:
12
- - src: >-
13
- https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg
14
- example_title: Portrait
15
  ---
16
  <h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>
17
 
 
8
  - model_hub_mixin
9
  repo_url: https://github.com/ZhengPeng7/BiRefNet-legacy
10
  pipeline_tag: image-to-image
 
 
 
 
11
  ---
12
  <h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>
13
 
handler.py CHANGED
@@ -9,7 +9,6 @@ from PIL import Image
9
  import torch
10
  from torchvision import transforms
11
  from transformers import AutoModelForImageSegmentation
12
- from loadimg import load_img
13
 
14
  torch.set_float32_matmul_precision(["high", "highest"][0])
15
 
@@ -105,7 +104,18 @@ class EndpointHandler():
105
  A :obj:`list` | `dict`: will be serialized and returned
106
  """
107
  print('data["inputs"] = ', data["inputs"])
108
- image = load_img(data["inputs"]).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
109
  # Preprocess the image
110
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
111
  image_proc = image_preprocessor.proc(image)
 
9
  import torch
10
  from torchvision import transforms
11
  from transformers import AutoModelForImageSegmentation
 
12
 
13
  torch.set_float32_matmul_precision(["high", "highest"][0])
14
 
 
104
  A :obj:`list` | `dict`: will be serialized and returned
105
  """
106
  print('data["inputs"] = ', data["inputs"])
107
+ image_src = data["inputs"]
108
+ if isinstance(image_src, str):
109
+ if os.path.isfile(image_src):
110
+ image_ori = Image.open(image_src)
111
+ else:
112
+ response = requests.get(image_src)
113
+ image_data = BytesIO(response.content)
114
+ image_ori = Image.open(image_data)
115
+ else:
116
+ image_ori = Image.fromarray(image_src)
117
+
118
+ image = image_ori.convert('RGB')
119
  # Preprocess the image
120
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
121
  image_proc = image_preprocessor.proc(image)
requirements.txt CHANGED
@@ -16,5 +16,3 @@ prettytable
16
  transformers
17
  huggingface-hub>0.25
18
  accelerate
19
-
20
- loadimg
 
16
  transformers
17
  huggingface-hub>0.25
18
  accelerate