Spaces:
Running
Running
from typing import List, Optional | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig | |
import torch | |
from haystack.nodes.base import BaseComponent | |
from haystack.modeling.utils import initialize_device_settings | |
from haystack.schema import Document | |
class EntailmentChecker(BaseComponent): | |
""" | |
This node checks the entailment between every document content and the query. | |
It enrichs the documents metadata with entailment informations. | |
It also returns aggregate entailment information. | |
""" | |
outgoing_edges = 1 | |
def __init__( | |
self, | |
model_name_or_path: str = "roberta-large-mnli", | |
model_version: Optional[str] = None, | |
tokenizer: Optional[str] = None, | |
use_gpu: bool = True, | |
batch_size: int = 16, | |
entailment_contradiction_threshold: float = 0.5, | |
): | |
""" | |
Load a Natural Language Inference model from Transformers. | |
:param model_name_or_path: Directory of a saved model or the name of a public model. | |
See https://huggingface.co/models for full list of available models. | |
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. | |
:param tokenizer: Name of the tokenizer (usually the same as model) | |
:param use_gpu: Whether to use GPU (if available). | |
:param batch_size: Number of Documents to be processed at a time. | |
:param entailment_contradiction_threshold: if in the first N documents there is a strong evidence of entailment/contradiction | |
(aggregate entailment or contradiction are greater than the threshold), the less relevant documents are not taken into account | |
""" | |
super().__init__() | |
self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False) | |
tokenizer = tokenizer or model_name_or_path | |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) | |
self.model = AutoModelForSequenceClassification.from_pretrained( | |
pretrained_model_name_or_path=model_name_or_path, revision=model_version | |
) | |
self.batch_size = batch_size | |
self.entailment_contradiction_threshold = entailment_contradiction_threshold | |
self.model.to(str(self.devices[0])) | |
id2label = AutoConfig.from_pretrained(model_name_or_path).id2label | |
self.labels = [id2label[k].lower() for k in sorted(id2label)] | |
if "entailment" not in self.labels: | |
raise ValueError( | |
"The model config must contain entailment value in the id2label dict." | |
) | |
def run(self, query: str, documents: List[Document]): | |
scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0 | |
for i, doc in enumerate(documents): | |
entailment_info = self.get_entailment(premise=doc.content, hypotesis=query) | |
doc.meta["entailment_info"] = entailment_info | |
scores += doc.score | |
con, neu, ent = ( | |
entailment_info["contradiction"], | |
entailment_info["neutral"], | |
entailment_info["entailment"], | |
) | |
agg_con += con * doc.score | |
agg_neu += neu * doc.score | |
agg_ent += ent * doc.score | |
# if in the first documents there is a strong evidence of entailment/contradiction, | |
# there is no need to consider less relevant documents | |
if max(agg_con, agg_ent) / scores > self.entailment_contradiction_threshold: | |
break | |
aggregate_entailment_info = { | |
"contradiction": round(agg_con / scores, 2), | |
"neutral": round(agg_neu / scores, 2), | |
"entailment": round(agg_ent / scores, 2), | |
} | |
entailment_checker_result = { | |
"documents": documents[: i + 1], | |
"aggregate_entailment_info": aggregate_entailment_info, | |
} | |
return entailment_checker_result, "output_1" | |
def run_batch(self, queries: List[str], documents: List[Document]): | |
pass | |
def get_entailment(self, premise, hypotesis): | |
with torch.inference_mode(): | |
inputs = self.tokenizer( | |
f"{premise}{self.tokenizer.sep_token}{hypotesis}", return_tensors="pt" | |
).to(self.devices[0]) | |
out = self.model(**inputs) | |
logits = out.logits | |
probs = ( | |
torch.nn.functional.softmax(logits, dim=-1)[0, :].detach().cpu().numpy() | |
) | |
entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)} | |
return entailment_dict | |