ZhengPeng7
commited on
Commit
·
1e0fb18
1
Parent(s):
63c51ee
Add inference endpoint feature.
Browse files- handler.py +4 -5
handler.py
CHANGED
@@ -75,9 +75,7 @@ usage_to_weights_file = {
|
|
75 |
}
|
76 |
|
77 |
usage = 'General'
|
78 |
-
|
79 |
-
birefnet.to(device)
|
80 |
-
birefnet.eval()
|
81 |
|
82 |
# Set resolution
|
83 |
if usage in ['General-Lite-2K']:
|
@@ -91,9 +89,10 @@ else:
|
|
91 |
class EndpointHandler():
|
92 |
def __init__(self, path=""):
|
93 |
self.birefnet = AutoModelForImageSegmentation.from_pretrained(
|
94 |
-
|
95 |
)
|
96 |
self.birefnet.to(device)
|
|
|
97 |
|
98 |
def __call__(self, data: Dict[str, Any]):
|
99 |
"""
|
@@ -123,7 +122,7 @@ class EndpointHandler():
|
|
123 |
|
124 |
# Prediction
|
125 |
with torch.no_grad():
|
126 |
-
preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
127 |
pred = preds[0].squeeze()
|
128 |
|
129 |
# Show Results
|
|
|
75 |
}
|
76 |
|
77 |
usage = 'General'
|
78 |
+
model_repo = '/'.join(('zhengpeng7', usage_to_weights_file[usage]))
|
|
|
|
|
79 |
|
80 |
# Set resolution
|
81 |
if usage in ['General-Lite-2K']:
|
|
|
89 |
class EndpointHandler():
|
90 |
def __init__(self, path=""):
|
91 |
self.birefnet = AutoModelForImageSegmentation.from_pretrained(
|
92 |
+
model_repo, trust_remote_code=True
|
93 |
)
|
94 |
self.birefnet.to(device)
|
95 |
+
self.birefnet.eval()
|
96 |
|
97 |
def __call__(self, data: Dict[str, Any]):
|
98 |
"""
|
|
|
122 |
|
123 |
# Prediction
|
124 |
with torch.no_grad():
|
125 |
+
preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
126 |
pred = preds[0].squeeze()
|
127 |
|
128 |
# Show Results
|