anakin87 committed on
Commit 4c41de2 · 1 Parent(s): 08d96b7

refactor EntailmentChecker: only relevant documents are used

Rock_fact_checker.py CHANGED

@@ -97,8 +97,8 @@ def main():
 
     # Display results
     if st.session_state.results:
-        results = st.session_state.results
-        docs, agg_entailment_info = results["documents"], results["agg_entailment_info"]
+        docs = st.session_state.results["documents"]
+        agg_entailment_info = st.session_state.results["aggregate_entailment_info"]
 
         # show different messages depending on entailment results
         max_key = max(agg_entailment_info, key=agg_entailment_info.get)
@@ -107,12 +107,11 @@ def main():
 
         st.markdown(f"###### Aggregate entailment information:")
        col1, col2 = st.columns([2, 1])
-        agg_entailment_info = results["agg_entailment_info"]
         fig = create_ternary_plot(agg_entailment_info)
         with col1:
             st.plotly_chart(fig, use_container_width=True)
         with col2:
-            st.write(results["agg_entailment_info"])
+            st.write(agg_entailment_info)
 
         st.markdown(f"###### Most Relevant snippets:")
         df, urls = create_df_for_relevant_snippets(docs)
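
After this change the app no longer aggregates anything itself; it reads the pipeline output directly. A minimal sketch of the structure st.session_state.results now carries (the doc_1/doc_2 placeholders and the probability values are invented for illustration):

    # Hypothetical example of the results dict the app now consumes;
    # doc_1 and doc_2 stand in for haystack.schema.Document objects.
    st.session_state.results = {
        "documents": [doc_1, doc_2],  # only the documents actually used for the verdict
        "aggregate_entailment_info": {  # score-weighted averages, rounded to 2 decimals
            "contradiction": 0.05,
            "neutral": 0.20,
            "entailment": 0.75,
        },
    }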
app_utils/backend_utils.py CHANGED

@@ -44,7 +44,11 @@ def start_haystack():
         embedding_model=RETRIEVER_MODEL,
         model_format=RETRIEVER_MODEL_FORMAT,
     )
-    entailment_checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)
+    entailment_checker = EntailmentChecker(
+        model_name_or_path=NLI_MODEL,
+        use_gpu=False,
+        entailment_contradiction_threshold=0.5,
+    )
 
     pipe = Pipeline()
     pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
@@ -60,30 +64,4 @@ pipe = start_haystack()
 def query(statement: str, retriever_top_k: int = 5):
     """Run query and verify statement"""
     params = {"retriever": {"top_k": retriever_top_k}}
-    results = pipe.run(statement, params=params)
-
-    scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
-    for i, doc in enumerate(results["documents"]):
-        scores += doc.score
-        ent_info = doc.meta["entailment_info"]
-        con, neu, ent = (
-            ent_info["contradiction"],
-            ent_info["neutral"],
-            ent_info["entailment"],
-        )
-        agg_con += con * doc.score
-        agg_neu += neu * doc.score
-        agg_ent += ent * doc.score
-
-        # if in the first documents there is a strong evidence of entailment/contradiction,
-        # there is no need to consider less relevant documents
-        if max(agg_con, agg_ent) / scores > 0.5:
-            results["documents"] = results["documents"][: i + 1]
-            break
-
-    results["agg_entailment_info"] = {
-        "contradiction": round(agg_con / scores, 2),
-        "neutral": round(agg_neu / scores, 2),
-        "entailment": round(agg_ent / scores, 2),
-    }
-    return results
+    return pipe.run(statement, params=params)
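
Since query() now simply forwards pipe.run(), truncation and aggregation happen inside EntailmentChecker and callers read them from the result. A usage sketch (the statement string and printed numbers are invented for illustration):

    # Hypothetical call to the simplified query() helper.
    result = query("Statement to fact-check", retriever_top_k=5)
    relevant_docs = result["documents"]  # already truncated to the relevant subset
    print(result["aggregate_entailment_info"])
    # e.g. {'contradiction': 0.81, 'neutral': 0.15, 'entailment': 0.04}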
app_utils/entailment_checker.py CHANGED

@@ -4,13 +4,14 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 import torch
 from haystack.nodes.base import BaseComponent
 from haystack.modeling.utils import initialize_device_settings
-from haystack.schema import Document, Answer, Span
+from haystack.schema import Document
 
 
 class EntailmentChecker(BaseComponent):
     """
     This node checks the entailment between every document content and the query.
-    It enrichs the documents metadata with entailment_info
+    It enriches the documents' metadata with entailment information.
+    It also returns aggregate entailment information.
     """
 
     outgoing_edges = 1
@@ -22,6 +23,7 @@ class EntailmentChecker(BaseComponent):
         tokenizer: Optional[str] = None,
         use_gpu: bool = True,
         batch_size: int = 16,
+        entailment_contradiction_threshold: float = 0.5,
     ):
         """
         Load a Natural Language Inference model from Transformers.
@@ -31,7 +33,9 @@ class EntailmentChecker(BaseComponent):
         :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
         :param tokenizer: Name of the tokenizer (usually the same as model)
         :param use_gpu: Whether to use GPU (if available).
-        # :param batch_size: Number of Documents to be processed at a time.
+        :param batch_size: Number of Documents to be processed at a time.
+        :param entailment_contradiction_threshold: if in the first documents there is strong evidence of entailment/contradiction
+            (aggregate entailment or contradiction greater than this threshold), the less relevant documents are not taken into account
         """
         super().__init__()
 
@@ -43,6 +47,7 @@ class EntailmentChecker(BaseComponent):
             pretrained_model_name_or_path=model_name_or_path, revision=model_version
         )
         self.batch_size = batch_size
+        self.entailment_contradiction_threshold = entailment_contradiction_threshold
         self.model.to(str(self.devices[0]))
 
         id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
@@ -53,12 +58,41 @@ class EntailmentChecker(BaseComponent):
         )
 
     def run(self, query: str, documents: List[Document]):
-        for doc in documents:
-            entailment_dict = self.get_entailment(premise=doc.content, hypotesis=query)
-            doc.meta["entailment_info"] = entailment_dict
-        return {"documents": documents}, "output_1"
 
-    def run_batch():
+        scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
+        for i, doc in enumerate(documents):
+            entailment_info = self.get_entailment(premise=doc.content, hypotesis=query)
+            doc.meta["entailment_info"] = entailment_info
+
+            scores += doc.score
+            con, neu, ent = (
+                entailment_info["contradiction"],
+                entailment_info["neutral"],
+                entailment_info["entailment"],
+            )
+            agg_con += con * doc.score
+            agg_neu += neu * doc.score
+            agg_ent += ent * doc.score
+
+            # if the first documents provide strong evidence of entailment/contradiction,
+            # there is no need to consider less relevant documents
+            if max(agg_con, agg_ent) / scores > self.entailment_contradiction_threshold:
+                break
+
+        aggregate_entailment_info = {
+            "contradiction": round(agg_con / scores, 2),
+            "neutral": round(agg_neu / scores, 2),
+            "entailment": round(agg_ent / scores, 2),
+        }
+
+        entailment_checker_result = {
+            "documents": documents[: i + 1],
+            "aggregate_entailment_info": aggregate_entailment_info,
+        }
+
+        return entailment_checker_result, "output_1"
+
+    def run_batch(self, queries: List[str], documents: List[Document]):
         pass
 
     def get_entailment(self, premise, hypotesis):
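
To see the early-stopping rule in action, here is a self-contained sketch of the score-weighted aggregation that run() implements; the retriever scores and per-document probabilities are made up:

    # Standalone illustration of the aggregation loop; all numbers are invented.
    docs = [
        {"score": 0.9, "contradiction": 0.05, "neutral": 0.15, "entailment": 0.80},
        {"score": 0.7, "contradiction": 0.10, "neutral": 0.20, "entailment": 0.70},
        {"score": 0.5, "contradiction": 0.60, "neutral": 0.30, "entailment": 0.10},
    ]
    threshold = 0.5
    scores = agg_con = agg_neu = agg_ent = 0
    for i, d in enumerate(docs):
        scores += d["score"]
        agg_con += d["contradiction"] * d["score"]
        agg_neu += d["neutral"] * d["score"]
        agg_ent += d["entailment"] * d["score"]
        # First iteration: agg_ent / scores == 0.80 > 0.5, so the loop breaks and
        # the less relevant second and third documents are never weighed in.
        if max(agg_con, agg_ent) / scores > threshold:
            break
    used_docs = docs[: i + 1]  # only the first document contributed to the verdict
    print(round(agg_ent / scores, 2))  # 0.8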