# HF file-page residue (author, commit message, hash, page links, size)
# preserved as comments so the module remains valid Python:
# minskiter's picture
# feat(model): update model parameters
# 57034b1
# raw / history / blame
# 2.82 kB
from typing import Any, Dict, Tuple
from transformers import Pipeline
from transformers.pipelines.base import GenericTensor
from transformers.utils import ModelOutput
from typing import Union,List
import torch
class EncodePipeline(Pipeline):
    """Pipeline that tokenizes text and returns sentence embeddings.

    Input may be a single tuple of strings or a list of tuples; either
    form is passed straight to the tokenizer. The output is the model's
    embedding tensor converted to plain (JSON-serializable) Python lists.
    """

    def __init__(self, max_length=256, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fixed tokenization length: every input is padded and truncated
        # to exactly this many tokens.
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        # No per-call parameters are supported; preprocess, forward and
        # postprocess all receive empty kwargs.
        return {}, {}, {}

    def preprocess(self, input: Union[Tuple[str], List[Tuple[str]]], **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
        """Tokenize *input* into fixed-length PyTorch tensors."""
        return self.tokenizer(
            input,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> List:
        """Run the model's encoder and return embeddings as nested lists."""
        # NOTE(review): assumes the model exposes an `encode` method
        # returning an embedding tensor — confirm against the model class.
        logits = self.model.encode(**input_tensors)
        # tolist() makes the result serializable (original annotation said
        # ModelOutput, but a plain list is what is actually returned).
        return logits.tolist()

    def postprocess(
        self,
        model_outputs: Any,
        **postprocess_parameters: Dict
    ) -> Any:
        # Embeddings are already plain lists; pass through unchanged.
        return model_outputs
class SimilarPipeline(Pipeline):
    """Pipeline computing cosine similarity between sentence pairs.

    Input is a single ``(a, b)`` tuple of strings or a list of such
    tuples. The two sides are tokenized separately, concatenated along
    the batch dimension, encoded in one forward pass, split back in
    half, and compared with cosine similarity.
    """

    def __init__(self, max_length=256, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fixed tokenization length: every input is padded and truncated
        # to exactly this many tokens.
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        # No per-call parameters are supported; preprocess, forward and
        # postprocess all receive empty kwargs.
        return {}, {}, {}

    def preprocess(self, input: Union[Tuple[str], List[Tuple[str]]], **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
        """Tokenize both sides of each pair and stack them into one batch.

        The first half of the returned batch holds the left-hand texts,
        the second half the right-hand texts; ``_forward`` relies on this
        ordering to split the encodings back apart.
        """
        if isinstance(input, list):
            # Transpose the list of pairs into two parallel lists.
            a, b = (list(side) for side in zip(*input))
        else:
            a, b = input[0], input[1]
        tensors = self.tokenizer(
            a,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        tensors_b = self.tokenizer(
            b,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Concatenate along the batch axis so one model call covers both
        # sides of every pair.
        for key in tensors:
            tensors[key] = torch.cat((tensors[key], tensors_b[key]), dim=0)
        return tensors

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> List:
        """Encode the stacked batch and return per-pair cosine similarity."""
        # NOTE(review): assumes the model returns a 2-tuple whose second
        # element is the embedding tensor — confirm against the model class.
        _, logits = self.model(**input_tensors)
        # First half of the batch is side `a`, second half is side `b`
        # (the ordering established in preprocess).
        half = logits.size(0) // 2
        logits_a = logits[:half]
        logits_b = logits[half:]
        scores = torch.nn.functional.cosine_similarity(logits_a, logits_b)
        # tolist() makes the result serializable (original annotation said
        # ModelOutput, but a plain list is what is actually returned).
        return scores.tolist()

    def postprocess(
        self,
        model_outputs: Any,
        **postprocess_parameters: Dict
    ) -> Any:
        # Scores are already plain floats in a list; pass through unchanged.
        return model_outputs