ZhengPeng7 commited on
Commit
1e0fb18
·
1 Parent(s): 63c51ee

Add inference endpoint feature.

Browse files
Files changed (1) hide show
  1. handler.py +4 -5
handler.py CHANGED
@@ -75,9 +75,7 @@ usage_to_weights_file = {
75
  }
76
 
77
  usage = 'General'
78
- birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True)
79
- birefnet.to(device)
80
- birefnet.eval()
81
 
82
  # Set resolution
83
  if usage in ['General-Lite-2K']:
@@ -91,9 +89,10 @@ else:
91
  class EndpointHandler():
92
  def __init__(self, path=""):
93
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
94
- "ZhengPeng7/BiRefNet", trust_remote_code=True
95
  )
96
  self.birefnet.to(device)
 
97
 
98
  def __call__(self, data: Dict[str, Any]):
99
  """
@@ -123,7 +122,7 @@ class EndpointHandler():
123
 
124
  # Prediction
125
  with torch.no_grad():
126
- preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
127
  pred = preds[0].squeeze()
128
 
129
  # Show Results
 
75
  }
76
 
77
  usage = 'General'
78
+ model_repo = '/'.join(('zhengpeng7', usage_to_weights_file[usage]))
 
 
79
 
80
  # Set resolution
81
  if usage in ['General-Lite-2K']:
 
89
  class EndpointHandler():
90
  def __init__(self, path=""):
91
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
92
+ model_repo, trust_remote_code=True
93
  )
94
  self.birefnet.to(device)
95
+ self.birefnet.eval()
96
 
97
  def __call__(self, data: Dict[str, Any]):
98
  """
 
122
 
123
  # Prediction
124
  with torch.no_grad():
125
+ preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
126
  pred = preds[0].squeeze()
127
 
128
  # Show Results