sgugger commited on
Commit
e27ee71
·
1 Parent(s): 2204461

Upload tool

Browse files
Files changed (2) hide show
  1. pair_classification_tool.py +23 -0
  2. tool_config.json +7 -0
pair_classification_tool.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
3
+ from transformers.tools import PipelineTool
4
+
5
+
6
+ class TextPairClassificationTool(PipelineTool):
7
+ default_checkpoint = "sgugger/bert-finetuned-mrpc"
8
+ pre_processor_class = AutoTokenizer
9
+ model_class = AutoModelForSequenceClassification
10
+
11
+ description = (
12
+ "This is a tool that classifies if two texts in English are similar or not using the labels 'equivalent' and "
13
+ "'not_equivalent'. It takes two inputs named `text` and `second_text` which should be in English and returns "
14
+ "the predicted label."
15
+ )
16
+
17
+ def encode(self, text, second_text):
18
+ return self.pre_processor(text, second_text, return_tensors="pt")
19
+
20
+ def decode(self, outputs):
21
+ logits = outputs.logits
22
+ label_id = torch.argmax(logits[0]).item()
23
+ return self.model.config.id2label[label_id]
tool_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "text_pair_classification": {
3
+ "description": "This is a tool that classifies if two texts in English are similar or not using the labels 'equivalent' and 'not_equivalent'. It takes two inputs named `text` and `second_text` which should be in English and returns the predicted label.",
4
+ "name": "pair_classifier",
5
+ "tool_class": "pair_classification_tool.TextPairClassificationTool"
6
+ }
7
+ }