Spaces:
Build error
Build error
import torch | |
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer | |
from transformers.tools import PipelineTool | |
class TextPairClassificationTool(PipelineTool):
    """Tool that labels whether two English texts are equivalent or not.

    The two texts are tokenized together as one sentence pair, passed through
    a sequence-classification model (by default the
    ``sgugger/bert-finetuned-mrpc`` checkpoint), and the highest-scoring class
    is returned as its label name from the model config.
    """

    default_checkpoint = "sgugger/bert-finetuned-mrpc"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    # Two text inputs (the pair), one text output (the predicted label).
    inputs = ["text", "text"]
    outputs = ["text"]
    description = (
        "This is a tool that classifies if two texts in English are similar or not using the labels 'equivalent' and "
        "'not_equivalent'. It takes two inputs named `text` and `second_text` which should be in English and returns "
        "the predicted label."
    )

    def encode(self, text, second_text):
        """Tokenize ``text`` and ``second_text`` as a single sentence pair.

        Args:
            text: First text of the pair.
            second_text: Second text of the pair.

        Returns:
            The tokenizer output with PyTorch tensors (``return_tensors="pt"``).
        """
        # Truncate pairs that exceed the tokenizer's maximum length. Without
        # this, an over-long pair is passed to a fixed-length model (BERT-style,
        # as the default checkpoint presumably is) and fails at inference time
        # instead of yielding a label. Inputs that already fit are unchanged.
        return self.pre_processor(
            text,
            second_text,
            truncation=True,
            return_tensors="pt",
        )

    def decode(self, outputs):
        """Turn the model's logits into a label string.

        Args:
            outputs: Model output exposing a ``logits`` tensor; only the first
                row is used, since ``encode`` produces a single pair.

        Returns:
            The label name found in ``self.model.config.id2label`` for the
            highest-scoring class.
        """
        logits = outputs.logits
        label_id = torch.argmax(logits[0]).item()
        return self.model.config.id2label[label_id]