anakin87 committed · commit 4c41de2 · 1 parent: 08d96b7

refactor EntailmentChecker: only relevant documents are used

In short: the score-weighted entailment aggregation that previously lived in query() in app_utils/backend_utils.py moves into EntailmentChecker.run(), which now also stops early and returns only the documents it actually used once aggregate entailment or contradiction exceeds a configurable threshold.

Files changed:
- Rock_fact_checker.py +3 -4
- app_utils/backend_utils.py +6 -28
- app_utils/entailment_checker.py +42 -8

Rock_fact_checker.py CHANGED

@@ -97,8 +97,8 @@ def main():
 
     # Display results
     if st.session_state.results:
-
-
+        docs = st.session_state.results["documents"]
+        agg_entailment_info = st.session_state.results["aggregate_entailment_info"]
 
         # show different messages depending on entailment results
        max_key = max(agg_entailment_info, key=agg_entailment_info.get)
@@ -107,12 +107,11 @@ def main():
 
         st.markdown(f"###### Aggregate entailment information:")
         col1, col2 = st.columns([2, 1])
-        agg_entailment_info = results["agg_entailment_info"]
         fig = create_ternary_plot(agg_entailment_info)
         with col1:
             st.plotly_chart(fig, use_container_width=True)
         with col2:
-            st.write(
+            st.write(agg_entailment_info)
 
         st.markdown(f"###### Most Relevant snippets:")
         df, urls = create_df_for_relevant_snippets(docs)
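
The app-side change is mechanical: both keys are now read straight from the pipeline output stored in session state. A minimal sketch of that consumption, with a made-up result dict standing in for a real pipeline run:

# Made-up stand-in for st.session_state.results after this refactor.
results = {
    "documents": [],  # only the documents the checker actually used
    "aggregate_entailment_info": {
        "contradiction": 0.05,
        "neutral": 0.15,
        "entailment": 0.80,
    },
}

docs = results["documents"]
agg_entailment_info = results["aggregate_entailment_info"]

# same logic as the app: the dominant class drives the displayed message
max_key = max(agg_entailment_info, key=agg_entailment_info.get)
print(max_key)  # -> "entailment"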

app_utils/backend_utils.py CHANGED

@@ -44,7 +44,11 @@ def start_haystack():
         embedding_model=RETRIEVER_MODEL,
         model_format=RETRIEVER_MODEL_FORMAT,
     )
-    entailment_checker = EntailmentChecker(
+    entailment_checker = EntailmentChecker(
+        model_name_or_path=NLI_MODEL,
+        use_gpu=False,
+        entailment_contradiction_threshold=0.5,
+    )
 
     pipe = Pipeline()
     pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
@@ -60,30 +64,4 @@ pipe = start_haystack()
 def query(statement: str, retriever_top_k: int = 5):
     """Run query and verify statement"""
     params = {"retriever": {"top_k": retriever_top_k}}
-    results = pipe.run(statement, params=params)
-
-    scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
-    for i, doc in enumerate(results["documents"]):
-        scores += doc.score
-        ent_info = doc.meta["entailment_info"]
-        con, neu, ent = (
-            ent_info["contradiction"],
-            ent_info["neutral"],
-            ent_info["entailment"],
-        )
-        agg_con += con * doc.score
-        agg_neu += neu * doc.score
-        agg_ent += ent * doc.score
-
-        # if in the first documents there is a strong evidence of entailment/contradiction,
-        # there is no need to consider less relevant documents
-        if max(agg_con, agg_ent) / scores > 0.5:
-            results["documents"] = results["documents"][: i + 1]
-            break
-
-    results["agg_entailment_info"] = {
-        "contradiction": round(agg_con / scores, 2),
-        "neutral": round(agg_neu / scores, 2),
-        "entailment": round(agg_ent / scores, 2),
-    }
-    return results
+    return pipe.run(statement, params=params)
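
After this change, query() is a thin wrapper around Pipeline.run, so the aggregate values come straight from the EntailmentChecker node. A hedged usage sketch (it assumes the Space's retriever and NLI models are available; the statement is illustrative and the output keys are the ones produced by the new EntailmentChecker.run):

# Hypothetical call against the pipeline built by start_haystack().
results = query("Elvis Presley is alive", retriever_top_k=5)

# "documents" is already truncated to the relevant documents;
# each one carries its per-document NLI probabilities in meta.
for doc in results["documents"]:
    print(doc.score, doc.meta["entailment_info"])

# score-weighted mix over the documents that were actually used
print(results["aggregate_entailment_info"])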

app_utils/entailment_checker.py CHANGED

@@ -4,13 +4,14 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 import torch
 from haystack.nodes.base import BaseComponent
 from haystack.modeling.utils import initialize_device_settings
-from haystack.schema import Document
+from haystack.schema import Document
 
 
 class EntailmentChecker(BaseComponent):
     """
     This node checks the entailment between every document content and the query.
-    It enrichs the documents metadata with
+    It enrichs the documents metadata with entailment informations.
+    It also returns aggregate entailment information.
     """
 
     outgoing_edges = 1
@@ -22,6 +23,7 @@ class EntailmentChecker(BaseComponent):
         tokenizer: Optional[str] = None,
         use_gpu: bool = True,
         batch_size: int = 16,
+        entailment_contradiction_threshold: float = 0.5,
     ):
         """
         Load a Natural Language Inference model from Transformers.
@@ -31,7 +33,9 @@ class EntailmentChecker(BaseComponent):
         :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
         :param tokenizer: Name of the tokenizer (usually the same as model)
         :param use_gpu: Whether to use GPU (if available).
-
+        :param batch_size: Number of Documents to be processed at a time.
+        :param entailment_contradiction_threshold: if in the first N documents there is a strong evidence of entailment/contradiction
+            (aggregate entailment or contradiction are greater than the threshold), the less relevant documents are not taken into account
         """
         super().__init__()
 
@@ -43,6 +47,7 @@ class EntailmentChecker(BaseComponent):
             pretrained_model_name_or_path=model_name_or_path, revision=model_version
         )
         self.batch_size = batch_size
+        self.entailment_contradiction_threshold = entailment_contradiction_threshold
         self.model.to(str(self.devices[0]))
 
         id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
@@ -53,12 +58,41 @@ class EntailmentChecker(BaseComponent):
         )
 
     def run(self, query: str, documents: List[Document]):
-        for doc in documents:
-            entailment_dict = self.get_entailment(premise=doc.content, hypotesis=query)
-            doc.meta["entailment_info"] = entailment_dict
-        return {"documents": documents}, "output_1"
 
-
+        scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
+        for i, doc in enumerate(documents):
+            entailment_info = self.get_entailment(premise=doc.content, hypotesis=query)
+            doc.meta["entailment_info"] = entailment_info
+
+            scores += doc.score
+            con, neu, ent = (
+                entailment_info["contradiction"],
+                entailment_info["neutral"],
+                entailment_info["entailment"],
+            )
+            agg_con += con * doc.score
+            agg_neu += neu * doc.score
+            agg_ent += ent * doc.score
+
+            # if in the first documents there is a strong evidence of entailment/contradiction,
+            # there is no need to consider less relevant documents
+            if max(agg_con, agg_ent) / scores > self.entailment_contradiction_threshold:
+                break
+
+        aggregate_entailment_info = {
+            "contradiction": round(agg_con / scores, 2),
+            "neutral": round(agg_neu / scores, 2),
+            "entailment": round(agg_ent / scores, 2),
+        }
+
+        entailment_checker_result = {
+            "documents": documents[: i + 1],
+            "aggregate_entailment_info": aggregate_entailment_info,
+        }
+
+        return entailment_checker_result, "output_1"
+
+    def run_batch(self, queries: List[str], documents: List[Document]):
         pass
 
     def get_entailment(self, premise, hypotesis):
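
The heart of the change is the score-weighted aggregation with early stopping that now lives in EntailmentChecker.run. A dependency-free sketch of just that logic, with made-up (score, probabilities) pairs standing in for the retrieved Documents and their NLI predictions:

# Documents ordered by retriever score, most relevant first (made-up numbers).
ranked = [
    (0.93, {"contradiction": 0.90, "neutral": 0.08, "entailment": 0.02}),
    (0.88, {"contradiction": 0.75, "neutral": 0.20, "entailment": 0.05}),
    (0.40, {"contradiction": 0.10, "neutral": 0.80, "entailment": 0.10}),
]
threshold = 0.5  # mirrors entailment_contradiction_threshold

scores = agg_con = agg_neu = agg_ent = 0.0
used = 0
for score, probs in ranked:
    scores += score
    agg_con += probs["contradiction"] * score
    agg_neu += probs["neutral"] * score
    agg_ent += probs["entailment"] * score
    used += 1
    # strong aggregate evidence in the top documents -> skip the rest
    if max(agg_con, agg_ent) / scores > threshold:
        break

print(used)  # 1: the first document alone is already decisive
print({
    "contradiction": round(agg_con / scores, 2),
    "neutral": round(agg_neu / scores, 2),
    "entailment": round(agg_ent / scores, 2),
})  # {'contradiction': 0.9, 'neutral': 0.08, 'entailment': 0.02}

Because of the early exit, the aggregate probabilities are a weighted average over only the documents seen so far, which is also why run() returns documents[: i + 1] rather than the full list.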