minskiter's picture
feat(similar.py): update pipeline
fdb0b54
from typing import Any, Dict, Tuple
from transformers import Pipeline
from transformers.pipelines.base import GenericTensor
from transformers.utils import ModelOutput
from typing import Union,List
class SimilarPipeline(Pipeline):
    """Pipeline that feeds text pairs to the model and returns its logits.

    Each input is a pair ``(first, second)``. Both texts are encoded in a
    single tokenizer call, padded/truncated to ``max_length`` tokens, and
    passed to the model. The result is the model's logits converted to
    nested Python lists. What the logits mean (presumably a similarity
    score, going by the class name) depends on the model -- confirm against
    the model this pipeline is paired with.
    """

    def __init__(self, max_length=512, *args, **kwargs):
        """Create the pipeline.

        Args:
            max_length: Default number of tokens every encoded pair is
                padded or truncated to.
            *args: Forwarded unchanged to ``Pipeline.__init__``.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        # NOTE: ``max_length`` is the first positional parameter, so the
        # model/tokenizer should be passed by keyword; a first positional
        # argument would be bound to ``max_length``.
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        """Route call-time keyword arguments to the pipeline stages.

        Only ``max_length`` is recognised; it is handed to ``preprocess`` so
        a single call can override the default set in ``__init__``. Any
        other keyword is ignored, as before.

        Returns:
            A 3-tuple of dicts: (preprocess, forward, postprocess) params.
        """
        preprocess_params = {}
        max_length = pipeline_parameters.get("max_length")
        # ``None`` means "not specified": keep the instance default.
        if max_length is not None:
            preprocess_params["max_length"] = max_length
        return preprocess_params, {}, {}

    @staticmethod
    def _as_pair(item):
        """Return ``(first, second)`` from one input pair, rejecting bare strings."""
        if isinstance(item, str):
            # Indexing a string would silently compare its first two
            # characters instead of two texts.
            raise TypeError(
                "SimilarPipeline expects a (text_a, text_b) pair, "
                f"got a single string: {item!r}"
            )
        return item[0], item[1]

    def preprocess(
        self,
        input: Union[Tuple[str, str], List[Tuple[str, str]]],
        **preprocess_parameters: Dict
    ) -> Dict[str, GenericTensor]:
        """Tokenize one text pair, or a list of text pairs, for the model.

        Args:
            input: A single ``(first, second)`` pair, or a list of such
                pairs. Only the first two elements of each pair are used.
            **preprocess_parameters: May contain ``max_length`` to override
                the instance default for this call.

        Returns:
            The tokenizer output as PyTorch tensors, padded and truncated
            to ``max_length``.

        Raises:
            TypeError: If a pair is a bare string instead of two texts.
        """
        if isinstance(input, list):
            pairs = [self._as_pair(item) for item in input]
            first_texts = [pair[0] for pair in pairs]
            second_texts = [pair[1] for pair in pairs]
        else:
            first_texts, second_texts = self._as_pair(input)
        max_length = preprocess_parameters.get("max_length", self.max_length)
        return self.tokenizer(
            first_texts,
            second_texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def _forward(
        self,
        input_tensors: Dict[str, GenericTensor],
        **forward_parameters: Dict
    ) -> List[Any]:
        """Run the model and return its logits as nested Python lists.

        Args:
            input_tensors: Keyword arguments for the model, as produced by
                ``preprocess``.
            **forward_parameters: Unused.

        Returns:
            ``logits.tolist()`` for the model's logits.
        """
        outputs = self.model(**input_tensors)
        # Prefer a named ``logits`` attribute when the model returns an
        # output object; otherwise keep the original contract of a
        # 2-tuple ``(_, logits)``.
        logits = getattr(outputs, "logits", None)
        if logits is None:
            _, logits = outputs
        return logits.tolist()

    def postprocess(
        self,
        model_outputs: Any,
        **postprocess_parameters: Dict
    ) -> Any:
        """Return the output of ``_forward`` unchanged.

        Args:
            model_outputs: The value produced by ``_forward``.
            **postprocess_parameters: Unused.
        """
        return model_outputs