|
from typing import Any, Dict, Tuple |
|
from transformers import Pipeline |
|
from transformers.pipelines.base import GenericTensor |
|
from transformers.utils import ModelOutput |
|
from typing import Union,List |
|
|
|
class SimilarPipeline(Pipeline):
    """Pipeline that encodes text pairs and returns the model's logits.

    Each input is a pair of texts. The pair is tokenized jointly (first
    text, second text), fed to ``self.model``, and the resulting logits are
    returned as plain Python values via ``.tolist()``.

    NOTE(review): the class name suggests the logits are similarity scores;
    that depends on the wrapped model and is not established here.
    """

    def __init__(self, max_length=512, *args, **kwargs):
        """Create the pipeline.

        Args:
            max_length: Default token length every encoded pair is padded or
                truncated to. Can be overridden per call with
                ``pipe(inputs, max_length=...)``.
            *args: Forwarded unchanged to ``Pipeline.__init__``.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Only ``max_length`` is recognised; it is routed to ``preprocess``.
        Any other keyword is ignored.

        Returns:
            A ``(preprocess_params, forward_params, postprocess_params)``
            tuple of dicts.
        """
        preprocess_params = {}
        requested_length = pipeline_parameters.get("max_length")
        # Only forward an explicit value so that omitting the argument keeps
        # the default chosen at construction time.
        if requested_length is not None:
            preprocess_params["max_length"] = requested_length
        return preprocess_params, {}, {}

    def preprocess(
        self,
        input: Union[Tuple[str, str], List[Tuple[str, str]]],
        **preprocess_parameters: Dict
    ) -> Dict[str, GenericTensor]:
        """Tokenize one text pair, or a list of text pairs, into tensors.

        Args:
            input: Either a single pair (indexable; items 0 and 1 are used)
                or a ``list`` of such pairs. Any ``list`` is treated as a
                batch of pairs, so a single pair must not be passed as a
                two-element list.
            **preprocess_parameters: May contain ``max_length`` to override
                ``self.max_length`` for this call.

        Returns:
            The tokenizer output, requested as PyTorch tensors.
        """
        if isinstance(input, list):
            first_texts = [pair[0] for pair in input]
            second_texts = [pair[1] for pair in input]
        else:
            first_texts = input[0]
            second_texts = input[1]

        max_length = preprocess_parameters.get("max_length")
        if max_length is None:
            max_length = self.max_length

        return self.tokenizer(
            first_texts,
            second_texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def _forward(
        self,
        input_tensors: Dict[str, GenericTensor],
        **forward_parameters: Dict
    ) -> Any:
        """Run the model and return its logits converted with ``.tolist()``.

        Args:
            input_tensors: Keyword arguments for ``self.model``, as produced
                by ``preprocess``.
            **forward_parameters: Unused.

        Returns:
            ``logits.tolist()`` for the logits produced by the model.

        Raises:
            KeyError: If the model returns a ``ModelOutput`` that has no
                ``"logits"`` entry.
        """
        outputs = self.model(**input_tensors)
        if isinstance(outputs, ModelOutput):
            # A ModelOutput is a mapping: unpacking it would yield its key
            # names rather than tensors, so fetch the logits by name.
            logits = outputs["logits"]
        else:
            # Otherwise the model is expected to return exactly two items,
            # the second of which is the logits.
            _, logits = outputs
        return logits.tolist()

    def postprocess(
        self,
        model_outputs: ModelOutput,
        **postprocess_parameters: Dict
    ) -> Any:
        """Return the forward-pass result unchanged.

        Args:
            model_outputs: The value returned by ``_forward``.
            **postprocess_parameters: Unused.
        """
        return model_outputs