# File size: 1,263 bytes
# Revision: 8360ec7
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# Load the tokenizer and the sequence-classification model for the MNLI
# checkpoint explicitly, then wrap both in a single pipeline object.
# The pipeline is created under the 'sentiment-analysis' task name, but the
# labels it emits are NLI labels (see the sample output noted in nli_infer).
_NLI_CHECKPOINT = "roberta-large-mnli"

tokenizer = AutoTokenizer.from_pretrained(_NLI_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_NLI_CHECKPOINT)
classifier = pipeline(
    task='sentiment-analysis',
    model=model,
    tokenizer=tokenizer,
)
# Integer code for each NLI label. The codes line up with the values used in
# stance_map (3 for NEUTRAL/irrelevant, 2 for CONTRADICTION/refute,
# 1 for ENTAILMENT/support).
nli_labelmap = {
    "NEUTRAL": 3,
    "CONTRADICTION": 2,
    "ENTAILMENT": 1,
}
# Signed stance value for each NLI label; nli_infer returns one of these
# values for the label predicted by the classifier.
nli2stance = {
    "NEUTRAL": 0,
    "CONTRADICTION": -1,
    "ENTAILMENT": 1,
}
# Integer code for each stance name. Note that both support variants
# ('partially-support' and 'completely-support') share the code 1.
stance_map = {
    'irrelevant': 3,
    'refute': 2,
    'partially-support': 1,
    'completely-support': 1,
}
def nli_infer(premise, hypothesis):
    """Predict the stance of one premise/hypothesis pair with the NLI model.

    The pair is joined into a single string using ``<s>``/``</s>`` markers
    and passed to the module-level ``classifier``. The label of the top
    prediction is then translated to a stance value through ``nli2stance``.

    If the classifier raises (the expected cause is an input longer than the
    model accepts, e.g. more than 514 tokens), the premise is cut to its
    first half (by characters) and the call is retried. Halving repeats
    until the call succeeds or the premise is empty; the hypothesis is never
    shortened.

    Args:
        premise: Premise text; may be truncated when the input is too long.
        hypothesis: Hypothesis text, always passed in full.

    Returns:
        The ``nli2stance`` value for the predicted label: 1 for
        ``ENTAILMENT``, -1 for ``CONTRADICTION``, 0 for ``NEUTRAL``.

    Raises:
        Exception: Whatever the classifier raised on the last attempt, when
            it still fails with an empty premise.
        KeyError: If the predicted label is not a key of ``nli2stance``.
    """
    # NOTE(review): the separator markers are embedded as literal text; the
    # pipeline's tokenizer may add its own special tokens on top of these --
    # confirm this is the intended input format.
    template = "<s>{}</s></s>{}</s></s>"

    while True:
        try:
            # Expected shape, e.g.:
            # [{'label': 'CONTRADICTION', 'score': 0.9992701411247253}]
            pred = classifier(template.format(premise, hypothesis))
            break
        except Exception:
            # `except Exception` (not a bare `except`) so KeyboardInterrupt
            # and SystemExit are not swallowed. The concrete error type for
            # an over-long input depends on the backend, hence the broad
            # class.
            if not premise:
                # Nothing left to trim: the failure is not (only) about the
                # premise length, so surface the original error.
                raise
            premise = premise[: len(premise) // 2]

    return nli2stance[pred[0]["label"]]