NLI-Mixer is an attempt to tackle the Natural Language Inference (NLI) task by mixing multiple datasets together.

The approach is simple:

  1. Combine all available NLI data without any domain-dependent re-balancing or re-weighting.
  2. Fine-tune several SOTA transformers of different sizes (20M to 300M parameters) on the combined data.
  3. Evaluate on challenging NLI datasets.

This model was trained using the SentenceTransformers CrossEncoder class. It is based on microsoft/deberta-v3-base.

Data

20+ NLI datasets were combined to train a binary classification model. The contradiction and neutral labels were combined to form a non-entailment class.
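
As a rough illustration of the two data steps described above (plain concatenation, then collapsing the three-way labels), here is a minimal sketch. The toy examples, the `to_binary_label` helper, and the choice of 1 for entailment are assumptions made for illustration, not the author's actual preprocessing.

```python
# Toy stand-ins for two NLI datasets as (premise, hypothesis, label) triples.
# The examples and label strings are hypothetical.
dataset_a = [
    ("A man is playing a guitar.", "A person is making music.", "entailment"),
    ("A man is playing a guitar.", "The man is asleep.", "contradiction"),
]
dataset_b = [
    ("Two dogs run across a field.", "The dogs are chasing a ball.", "neutral"),
]

THREE_WAY_LABELS = {"entailment", "neutral", "contradiction"}


def to_binary_label(label: str) -> int:
    """Return 1 for entailment and 0 for non-entailment (assumed encoding)."""
    normalized = label.strip().lower()
    if normalized not in THREE_WAY_LABELS:
        raise ValueError(f"Unexpected NLI label: {label!r}")
    return 1 if normalized == "entailment" else 0


# Plain concatenation: no per-dataset re-balancing or re-weighting.
combined = dataset_a + dataset_b
binary_examples = [(p, h, to_binary_label(y)) for p, h, y in combined]

for premise, hypothesis, target in binary_examples:
    print(target, "|", premise, "->", hypothesis)
```

Real NLI datasets often encode labels as integers with differing conventions, so a per-dataset mapping would likely be needed before a step like this.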

Usage

In Transformers

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "ragarwal/deberta-v3-base-nli-mixer-binary"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

sentence = "During its monthly call, the National Oceanic and Atmospheric Administration warned of \
increased temperatures and low precipitation"
labels = ["Computer", "Climate Change", "Tablet", "Football", "Artificial Intelligence", "Global Warming"]

# Pair the sentence (premise) with each candidate label (hypothesis).
features = tokenizer([[sentence, l] for l in labels], padding=True, truncation=True, return_tensors="pt").to(device)

model.eval()
with torch.no_grad():
    scores = model(**features).logits.cpu()  # one logit per (sentence, label) pair
    print("Multi-Label:", torch.sigmoid(scores))  # Multi-Label Classification: each label scored independently
    print("Single-Label:", torch.softmax(scores, dim=0))  # Single-Label Classification: normalized across labels

#Multi-Label: tensor([[0.0412],[0.2436],[0.0394],[0.0020],[0.0050],[0.1424]])
#Single-Label: tensor([[0.0742],[0.5561],[0.0709],[0.0035],[0.0087],[0.2867]])
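
The two readings come from the same logits: sigmoid scores each label on its own, while softmax over the label dimension normalizes across the candidate set. Since exp(logit) equals the odds p / (1 - p), the single-label values can be recovered from the multi-label ones. The following pure-Python check uses the rounded multi-label values printed above; the variable names are illustrative.

```python
# Multi-label (sigmoid) scores copied from the example output, rounded to 4 decimals.
multi_label = [0.0412, 0.2436, 0.0394, 0.0020, 0.0050, 0.1424]

# exp(logit) equals the odds p / (1 - p), so a softmax over the logits
# is the same as normalizing the odds across labels.
odds = [p / (1.0 - p) for p in multi_label]
total = sum(odds)
single_label = [o / total for o in odds]

print([round(s, 3) for s in single_label])
```

The result should agree with the single-label output above up to rounding, which is a quick way to confirm that both views rank the labels identically and differ only in normalization.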

In Sentence-Transformers

from sentence_transformers import CrossEncoder

model_name="ragarwal/deberta-v3-base-nli-mixer-binary"
model = CrossEncoder(model_name, max_length=256)

sentence = "During its monthly call, the National Oceanic and Atmospheric Administration warned of \
increased temperatures and low precipitation" 
labels = ["Computer", "Climate Change", "Tablet", "Football", "Artificial Intelligence", "Global Warming"] 

scores = model.predict([[sentence, l] for l in labels])
print(scores) 
#array([0.04118565, 0.2435827 , 0.03941465, 0.00203637, 0.00501176, 0.1423797], dtype=float32)
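
To turn these scores into a prediction, one option is to rank the labels and either take the top one or keep every label above a cutoff. The sketch below reuses the scores printed above; the 0.1 threshold is an arbitrary value chosen for illustration, not a recommended setting.

```python
labels = ["Computer", "Climate Change", "Tablet", "Football", "Artificial Intelligence", "Global Warming"]
scores = [0.04118565, 0.2435827, 0.03941465, 0.00203637, 0.00501176, 0.1423797]

# Rank labels from most to least likely to be entailed by the sentence.
ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)

# Single-label reading: take the highest-scoring label.
top_label = ranked[0][0]

# Multi-label reading: keep every label at or above a cutoff.
THRESHOLD = 0.1  # hypothetical value; tune on held-out data
selected = [label for label, score in ranked if score >= THRESHOLD]

print(top_label)  # Climate Change
print(selected)   # ['Climate Change', 'Global Warming']
```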