File size: 900 Bytes
a6b176e 6d32a3c a6b176e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from typing import Any, Dict, List, Optional

from sentence_transformers import SentenceTransformer
class EndpointHandler:
    """Inference-endpoint handler that embeds text with a SentenceTransformer.

    The model is loaded once in ``__init__`` and reused by every request
    dispatched through ``__call__``.
    """

    def __init__(self, path: str = "", device: Optional[str] = None):
        """Load the SentenceTransformer model found at ``path``.

        Args:
            path: Model directory or identifier handed to ``SentenceTransformer``.
            device: Torch device to place the model on (e.g. ``"cuda"`` or
                ``"cpu"``). When ``None`` (the default), CUDA is used if it is
                available and the CPU otherwise, rather than crashing on hosts
                without a GPU.
        """
        print("MODEL INIT")
        if device is None:
            # Imported lazily: torch is only needed here to probe for a GPU.
            import torch

            device = "cuda" if torch.cuda.is_available() else "cpu"
        # trust_remote_code=True allows the model repository to execute its own
        # modelling code -- only point `path` at trusted models.
        self.model = SentenceTransformer(path, trust_remote_code=True, device=device)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Embed the request's text(s) either as queries or as documents.

        Args:
            data: Request payload. ``data["inputs"]`` holds the text(s) to embed;
                if that key is missing, the payload dict itself is used as the
                inputs. ``data["type"]`` selects ``"query"`` or ``"doc"`` and
                defaults to ``"doc"``. Both keys are popped, i.e. ``data`` is
                mutated in place.

        Returns:
            The result of ``self.model.encode`` for the inputs (typically an
            array of embeddings); the serving layer serializes it.

        Raises:
            ValueError: If ``type`` is neither ``"query"`` nor ``"doc"``.
        """
        inputs = data.pop("inputs", data)
        request_type = data.pop("type", "doc")
        if request_type == "query":
            # Queries are encoded with the "s2p_query" prompt; presumably defined
            # in the model's sentence-transformers prompt config -- verify.
            return self.model.encode(inputs, prompt_name="s2p_query")
        if request_type == "doc":
            # Documents are encoded without any prompt.
            return self.model.encode(inputs)
        raise ValueError(
            f"Invalid request type {request_type!r}; expected 'query' or 'doc'"
        )
|