Hasan Iqbal
commited on
Added LLM Evaluation on Datasets
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .github/workflows/{main.yaml → release.yaml} +2 -2
- .github/workflows/release_patch.yaml +41 -0
- .gitignore +1 -1
- pyproject.toml +5 -0
- src/openfactcheck/app/__init__.py +0 -4
- src/openfactcheck/app/app.py +4 -3
- src/openfactcheck/app/evaluate_factchecker.py +13 -31
- src/openfactcheck/app/evaluate_llm.py +199 -0
- src/openfactcheck/app/evaluate_response.py +21 -6
- src/openfactcheck/core/base.py +61 -116
- src/openfactcheck/core/cli.py +8 -1
- src/openfactcheck/core/solver.py +12 -2
- src/openfactcheck/{factchecker → data}/__init__.py +0 -0
- src/openfactcheck/data/factcheck-bench.jsonl +0 -0
- src/openfactcheck/data/factoolqa.jsonl +0 -0
- src/openfactcheck/data/factscore-bio.jsonl +0 -0
- src/openfactcheck/data/felm-wk.jsonl +0 -0
- src/openfactcheck/data/freshqa.jsonl +0 -0
- src/openfactcheck/data/selfaware.jsonl +0 -0
- src/openfactcheck/data/snowballing.jsonl +0 -0
- src/openfactcheck/{solvers/factool/ftool_utils → evaluator}/__init__.py +0 -0
- src/openfactcheck/evaluator/factchecker/__init__.py +0 -0
- src/openfactcheck/{factchecker → evaluator/factchecker}/evaluate.py +38 -28
- src/openfactcheck/evaluator/llm/__init__.py +0 -0
- src/openfactcheck/evaluator/llm/evaluate.py +372 -0
- src/openfactcheck/evaluator/llm/evaluate_freetext.py +305 -0
- src/openfactcheck/evaluator/llm/evaluate_freshqa.py +308 -0
- src/openfactcheck/evaluator/llm/evaluate_selfaware.py +284 -0
- src/openfactcheck/evaluator/llm/evaluate_snowballing.py +184 -0
- src/openfactcheck/evaluator/llm/report.py +82 -0
- src/openfactcheck/evaluator/response/__init__.py +0 -0
- src/openfactcheck/evaluator/response/evaluate.py +160 -0
- src/openfactcheck/lib/config.py +11 -10
- src/openfactcheck/solvers/factool/all_pass_abstain_detector.py +4 -3
- src/openfactcheck/solvers/factool/concat_response_regenerator.py +6 -6
- src/openfactcheck/solvers/factool/factool_blackbox.py +8 -23
- src/openfactcheck/solvers/factool/factool_blackbox_post_editor.py +8 -48
- src/openfactcheck/solvers/factool/factool_claim_examiner.py +36 -59
- src/openfactcheck/solvers/factool/factool_decontextualizer.py +12 -29
- src/openfactcheck/solvers/factool/factool_evidence_retriever.py +18 -34
- src/openfactcheck/solvers/factool/factool_example.py +0 -18
- src/openfactcheck/solvers/factool/factool_post_editor.py +7 -10
- src/openfactcheck/solvers/factool/factool_utils/__init__.py +0 -0
- src/openfactcheck/solvers/factool/{ftool_utils → factool_utils}/chat_api.py +14 -40
- src/openfactcheck/solvers/factool/factool_utils/prompt.py +69 -0
- src/openfactcheck/solvers/factool/{ftool_utils → factool_utils}/search_api.py +0 -0
- src/openfactcheck/solvers/factool/ftool_utils/prompts.yaml +0 -67
- src/openfactcheck/solvers/webservice/__init__.py +0 -0
- src/openfactcheck/solvers/webservice/factool_utils/chat_api.py +1 -2
- src/openfactcheck/solvers/webservice/ftool_cp.py +0 -3
.github/workflows/{main.yaml → release.yaml}
RENAMED
@@ -1,8 +1,8 @@
|
|
1 |
-
name:
|
2 |
|
3 |
on:
|
4 |
push:
|
5 |
-
tags: ["v[0-9]
|
6 |
branches: ["*"]
|
7 |
|
8 |
jobs:
|
|
|
1 |
+
name: Release
|
2 |
|
3 |
on:
|
4 |
push:
|
5 |
+
tags: ["v[0-9]+.[0-9].0"]
|
6 |
branches: ["*"]
|
7 |
|
8 |
jobs:
|
.github/workflows/release_patch.yaml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release Patch
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
tags: ["v[0-9]+.[0-9]+.[0-9]+"]
|
6 |
+
branches: ["*"]
|
7 |
+
|
8 |
+
workflow_dispatch:
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
huggingface_deploy:
|
12 |
+
name: Deploy to Hugging Face
|
13 |
+
runs-on: ubuntu-latest
|
14 |
+
if: ${{ github.ref_type == 'tag' && startsWith(github.ref_name, 'v') }}
|
15 |
+
steps:
|
16 |
+
- name: Check out code
|
17 |
+
uses: actions/checkout@v3
|
18 |
+
with:
|
19 |
+
ref: ${{ github.head_ref }}
|
20 |
+
fetch-depth: 0
|
21 |
+
|
22 |
+
- name: Setup Git Config
|
23 |
+
run: |
|
24 |
+
git checkout main
|
25 |
+
git config --global user.email "[email protected]"
|
26 |
+
git config --global user.name "Hasan Iqbal"
|
27 |
+
|
28 |
+
- name: Enable HuggingFace Spaces Config
|
29 |
+
run: |
|
30 |
+
READMEFILE="README.md"
|
31 |
+
sed -i 's/<!--/---/g; s/-->/---/g' $READMEFILE
|
32 |
+
git add $READMEFILE
|
33 |
+
git commit -m "Uncommented HuggingFace Spaces Configuration" || echo "No changes to commit"
|
34 |
+
|
35 |
+
- name: Wait for 10 seconds
|
36 |
+
run: sleep 10
|
37 |
+
|
38 |
+
- name: Push to Hugging Face
|
39 |
+
env:
|
40 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
41 |
+
run: git push https://hasaniqbal777:[email protected]/spaces/hasaniqbal777/OpenFactCheck main --force
|
.gitignore
CHANGED
@@ -12,7 +12,7 @@ ignore.*.tfvars
|
|
12 |
# Project Specific
|
13 |
DEVNOTES
|
14 |
tmp/
|
15 |
-
|
16 |
|
17 |
# Python Specific
|
18 |
.venv/
|
|
|
12 |
# Project Specific
|
13 |
DEVNOTES
|
14 |
tmp/
|
15 |
+
dev*.json
|
16 |
|
17 |
# Python Specific
|
18 |
.venv/
|
pyproject.toml
CHANGED
@@ -8,6 +8,9 @@ version = "0.1.3"
|
|
8 |
authors = [
|
9 |
{name = "Hasan Iqbal", email = "[email protected]"},
|
10 |
{name = "Yuxia Wang", email = "[email protected]"},
|
|
|
|
|
|
|
11 |
]
|
12 |
maintainers = [
|
13 |
{name = "Hasan Iqbal", email = "[email protected]"},
|
@@ -45,5 +48,7 @@ openfactcheck = [
|
|
45 |
"templates/factchecker/*.jsonl",
|
46 |
"templates/factchecker/*.csv",
|
47 |
"templates/factchecker/**/*.jsonl",
|
|
|
|
|
48 |
]
|
49 |
|
|
|
8 |
authors = [
|
9 |
{name = "Hasan Iqbal", email = "[email protected]"},
|
10 |
{name = "Yuxia Wang", email = "[email protected]"},
|
11 |
+
{name = "Minghan Wang", email = "[email protected]"},
|
12 |
+
{name = "Jiahui Geng", email = "[email protected]"},
|
13 |
+
{name = "Preslav Nakov", email = "[email protected]"}
|
14 |
]
|
15 |
maintainers = [
|
16 |
{name = "Hasan Iqbal", email = "[email protected]"},
|
|
|
48 |
"templates/factchecker/*.jsonl",
|
49 |
"templates/factchecker/*.csv",
|
50 |
"templates/factchecker/**/*.jsonl",
|
51 |
+
"templates/llm/*.csv",
|
52 |
+
"templates/report/*.tex"
|
53 |
]
|
54 |
|
src/openfactcheck/app/__init__.py
CHANGED
@@ -1,4 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
|
3 |
-
sys.path.append('../')
|
4 |
-
sys.path.append("/home/user/app/src")
|
|
|
|
|
|
|
|
|
|
src/openfactcheck/app/app.py
CHANGED
@@ -5,6 +5,7 @@ from streamlit_option_menu import option_menu
|
|
5 |
from openfactcheck.core.base import OpenFactCheck, OpenFactCheckConfig
|
6 |
from openfactcheck.app.sidebar import sidebar
|
7 |
from openfactcheck.app.evaluate_response import evaluate_response
|
|
|
8 |
from openfactcheck.app.evaluate_factchecker import evaluate_factchecker
|
9 |
from openfactcheck.app.utils import footer
|
10 |
|
@@ -63,10 +64,10 @@ class App:
|
|
63 |
# Load the selected page
|
64 |
if selected == "Evaluate LLM Response":
|
65 |
evaluate_response(self.ofc)
|
66 |
-
|
67 |
-
|
68 |
elif selected == "Evaluate FactChecker":
|
69 |
-
evaluate_factchecker()
|
70 |
# elif selected == "Leaderboards":
|
71 |
# leaderboards()
|
72 |
# else:
|
|
|
5 |
from openfactcheck.core.base import OpenFactCheck, OpenFactCheckConfig
|
6 |
from openfactcheck.app.sidebar import sidebar
|
7 |
from openfactcheck.app.evaluate_response import evaluate_response
|
8 |
+
from openfactcheck.app.evaluate_llm import evaluate_llm
|
9 |
from openfactcheck.app.evaluate_factchecker import evaluate_factchecker
|
10 |
from openfactcheck.app.utils import footer
|
11 |
|
|
|
64 |
# Load the selected page
|
65 |
if selected == "Evaluate LLM Response":
|
66 |
evaluate_response(self.ofc)
|
67 |
+
elif selected == "Evaluate LLM":
|
68 |
+
evaluate_llm(self.ofc)
|
69 |
elif selected == "Evaluate FactChecker":
|
70 |
+
evaluate_factchecker(self.ofc)
|
71 |
# elif selected == "Leaderboards":
|
72 |
# leaderboards()
|
73 |
# else:
|
src/openfactcheck/app/evaluate_factchecker.py
CHANGED
@@ -8,17 +8,21 @@ import matplotlib.pyplot as plt
|
|
8 |
from importlib import resources as pkg_resources
|
9 |
|
10 |
from openfactcheck.app.utils import metric_card
|
11 |
-
from openfactcheck.
|
12 |
from openfactcheck.templates import factchecker as templates_dir
|
13 |
|
14 |
# Import solver configuration templates
|
15 |
claims_templates_path = str(pkg_resources.files(templates_dir) / "claims.jsonl")
|
16 |
documents_templates_path = str(pkg_resources.files(templates_dir) / "documents.jsonl")
|
17 |
|
18 |
-
def evaluate_factchecker():
|
19 |
"""
|
20 |
This function creates a Streamlit app to evaluate a Factchecker.
|
21 |
"""
|
|
|
|
|
|
|
|
|
22 |
st.write("This is where you can evaluate the factuality of a FactChecker.")
|
23 |
|
24 |
# Display the instructions
|
@@ -111,46 +115,24 @@ def evaluate_factchecker():
|
|
111 |
|
112 |
# Display a waiting message
|
113 |
with st.status("Evaluating factuality of the FactChecker...", expanded=True) as status:
|
114 |
-
|
115 |
-
fce = FactCheckerEvaluator(input=uploaded_data, eval_type="claims")
|
116 |
-
result = fce()
|
117 |
status.update(label="FactChecker evaluated...", state="complete", expanded=False)
|
118 |
|
119 |
# Display the evaluation report
|
120 |
st.write("### Evaluation report:")
|
121 |
-
|
122 |
-
"""
|
123 |
-
{
|
124 |
-
"True_as_positive": {
|
125 |
-
"accuracy": 0.486,
|
126 |
-
"precision": 0.71,
|
127 |
-
"recall": 0.478,
|
128 |
-
"F1": 0.571
|
129 |
-
},
|
130 |
-
"False_as_positive": {
|
131 |
-
"accuracy": 0.486,
|
132 |
-
"precision": 0.277,
|
133 |
-
"recall": 0.506,
|
134 |
-
"F1": 0.358
|
135 |
-
},
|
136 |
-
"total_time": 14430.0,
|
137 |
-
"total_cost": 144.3,
|
138 |
-
"num_samples": 1443
|
139 |
-
}
|
140 |
-
"""
|
141 |
|
142 |
col1, col2 = st.columns(2, gap="large")
|
143 |
with col1:
|
144 |
# Create the heatmap
|
145 |
classes = ['True', 'False']
|
146 |
fig = plt.figure()
|
147 |
-
sns.heatmap(
|
148 |
plt.ylabel('Actual Class')
|
149 |
plt.xlabel('Predicted Class')
|
150 |
st.pyplot(fig)
|
151 |
with col2:
|
152 |
# Display the metrics
|
153 |
-
accuracy =
|
154 |
if accuracy > 0.75 and accuracy <= 1:
|
155 |
# Green background
|
156 |
metric_card(label="Accuracy", value=f"{accuracy:.2%}", background_color="#D4EDDA", border_left_color="#28A745")
|
@@ -163,13 +145,13 @@ def evaluate_factchecker():
|
|
163 |
|
164 |
sub_col1, sub_col2, sub_col3 = st.columns(3)
|
165 |
with sub_col1:
|
166 |
-
metric_card(label="Total Time", value=
|
167 |
with sub_col2:
|
168 |
-
metric_card(label="Total Cost", value=
|
169 |
with sub_col3:
|
170 |
-
metric_card(label="Number of Samples", value=
|
171 |
|
172 |
-
st.text("Report:\n" +
|
173 |
|
174 |
|
175 |
|
|
|
8 |
from importlib import resources as pkg_resources
|
9 |
|
10 |
from openfactcheck.app.utils import metric_card
|
11 |
+
from openfactcheck.core.base import OpenFactCheck
|
12 |
from openfactcheck.templates import factchecker as templates_dir
|
13 |
|
14 |
# Import solver configuration templates
|
15 |
claims_templates_path = str(pkg_resources.files(templates_dir) / "claims.jsonl")
|
16 |
documents_templates_path = str(pkg_resources.files(templates_dir) / "documents.jsonl")
|
17 |
|
18 |
+
def evaluate_factchecker(ofc: OpenFactCheck):
|
19 |
"""
|
20 |
This function creates a Streamlit app to evaluate a Factchecker.
|
21 |
"""
|
22 |
+
|
23 |
+
# Initialize the FactChecker Evaluator
|
24 |
+
fc_evaluator = ofc.FactCheckerEvaluator
|
25 |
+
|
26 |
st.write("This is where you can evaluate the factuality of a FactChecker.")
|
27 |
|
28 |
# Display the instructions
|
|
|
115 |
|
116 |
# Display a waiting message
|
117 |
with st.status("Evaluating factuality of the FactChecker...", expanded=True) as status:
|
118 |
+
result = fc_evaluator.evaluate(input_path=uploaded_data, eval_type="claims")
|
|
|
|
|
119 |
status.update(label="FactChecker evaluated...", state="complete", expanded=False)
|
120 |
|
121 |
# Display the evaluation report
|
122 |
st.write("### Evaluation report:")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
col1, col2 = st.columns(2, gap="large")
|
125 |
with col1:
|
126 |
# Create the heatmap
|
127 |
classes = ['True', 'False']
|
128 |
fig = plt.figure()
|
129 |
+
sns.heatmap(fc_evaluator.confusion_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
|
130 |
plt.ylabel('Actual Class')
|
131 |
plt.xlabel('Predicted Class')
|
132 |
st.pyplot(fig)
|
133 |
with col2:
|
134 |
# Display the metrics
|
135 |
+
accuracy = fc_evaluator.results["True_as_positive"]["accuracy"]
|
136 |
if accuracy > 0.75 and accuracy <= 1:
|
137 |
# Green background
|
138 |
metric_card(label="Accuracy", value=f"{accuracy:.2%}", background_color="#D4EDDA", border_left_color="#28A745")
|
|
|
145 |
|
146 |
sub_col1, sub_col2, sub_col3 = st.columns(3)
|
147 |
with sub_col1:
|
148 |
+
metric_card(label="Total Time", value=fc_evaluator.results["total_time"])
|
149 |
with sub_col2:
|
150 |
+
metric_card(label="Total Cost", value=fc_evaluator.results["total_cost"])
|
151 |
with sub_col3:
|
152 |
+
metric_card(label="Number of Samples", value=fc_evaluator.results["num_samples"])
|
153 |
|
154 |
+
st.text("Report:\n" + fc_evaluator.classification_report)
|
155 |
|
156 |
|
157 |
|
src/openfactcheck/app/evaluate_llm.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import uuid
|
3 |
+
import zipfile
|
4 |
+
import pandas as pd
|
5 |
+
import streamlit as st
|
6 |
+
from importlib import resources as pkg_resources
|
7 |
+
|
8 |
+
from openfactcheck.core.base import OpenFactCheck
|
9 |
+
from openfactcheck.app.utils import metric_card
|
10 |
+
from openfactcheck.templates import llm as templates_dir
|
11 |
+
|
12 |
+
# Import solver configuration templates
|
13 |
+
questions_templates_path = str(pkg_resources.files(templates_dir) / "questions.csv")
|
14 |
+
|
15 |
+
def evaluate_llm(ofc: OpenFactCheck):
|
16 |
+
"""
|
17 |
+
This function creates a Streamlit app to evaluate the factuality of a LLM.
|
18 |
+
"""
|
19 |
+
# Initialize the LLM Evaluator
|
20 |
+
llm_evaluator = ofc.LLMEvaluator
|
21 |
+
|
22 |
+
st.write("This is where you can evaluate the factuality of a LLM.")
|
23 |
+
|
24 |
+
# Display the instructions
|
25 |
+
st.write("Download the questions and instructions to evaluate the factuality of a LLM.")
|
26 |
+
|
27 |
+
# Check if the file exists
|
28 |
+
if os.path.exists(questions_templates_path):
|
29 |
+
# Create a ZIP file in memory
|
30 |
+
from io import BytesIO
|
31 |
+
memory_file = BytesIO()
|
32 |
+
with zipfile.ZipFile(memory_file, 'w') as zf:
|
33 |
+
# Define the name of the file within the ZIP archive
|
34 |
+
zip_path = os.path.basename(questions_templates_path) # 'questions.csv'
|
35 |
+
# Add file to the ZIP file
|
36 |
+
zf.write(questions_templates_path, arcname=zip_path)
|
37 |
+
|
38 |
+
# Reset pointer to start of the memory file
|
39 |
+
memory_file.seek(0)
|
40 |
+
|
41 |
+
# Create a download button and the file will be downloaded when clicked
|
42 |
+
btn = st.download_button(
|
43 |
+
label="Download",
|
44 |
+
data=memory_file,
|
45 |
+
file_name="openfactcheck_llm_benchmark.zip",
|
46 |
+
mime="application/zip"
|
47 |
+
)
|
48 |
+
else:
|
49 |
+
st.error("File not found.")
|
50 |
+
|
51 |
+
# Display the instructions
|
52 |
+
st.write("Upload the model responses as a JSON file below to evaluate the factuality.")
|
53 |
+
|
54 |
+
# Upload the model output
|
55 |
+
uploaded_file = st.file_uploader("Upload", type=["csv"], label_visibility="collapsed")
|
56 |
+
|
57 |
+
# Check if the file is uploaded
|
58 |
+
if uploaded_file is None:
|
59 |
+
st.info("Please upload a CSV file.")
|
60 |
+
return
|
61 |
+
|
62 |
+
# Check if the file is a CSV file
|
63 |
+
if uploaded_file.type != "text/csv":
|
64 |
+
st.error("Invalid file format. Please upload a CSV file.")
|
65 |
+
return
|
66 |
+
|
67 |
+
# Read the CSV file
|
68 |
+
uploaded_data = pd.read_csv(uploaded_file)
|
69 |
+
|
70 |
+
# Ask user to select datasets they want to evaluate on
|
71 |
+
st.write("Please select the datasets you want to evaluate the LLM on.")
|
72 |
+
datasets = st.multiselect("Select datasets", ["snowballing", "selfaware", "freshqa", "factoolqa", "felm-wk", "factcheck-bench", "factscore-bio"])
|
73 |
+
|
74 |
+
def update_first_name():
|
75 |
+
st.session_state.first_name = st.session_state.input_first_name
|
76 |
+
|
77 |
+
def update_last_name():
|
78 |
+
st.session_state.last_name = st.session_state.input_last_name
|
79 |
+
|
80 |
+
def update_email():
|
81 |
+
st.session_state.email = st.session_state.input_email
|
82 |
+
|
83 |
+
def update_organization():
|
84 |
+
st.session_state.organization = st.session_state.input_organization
|
85 |
+
|
86 |
+
def update_llm_model():
|
87 |
+
st.session_state.llm_model = st.session_state.input_llm_model
|
88 |
+
|
89 |
+
def update_include_in_leaderboard():
|
90 |
+
st.session_state.include_in_leaderboard = st.session_state.input_include_in_leaderboard
|
91 |
+
|
92 |
+
# Display instructions
|
93 |
+
st.write("Please provide the following information to be included in the leaderboard.")
|
94 |
+
|
95 |
+
# Create text inputs to enter the user information
|
96 |
+
st.session_state.id = uuid.uuid4().hex
|
97 |
+
st.text_input("First Name", key="input_first_name", on_change=update_first_name)
|
98 |
+
st.text_input("Last Name", key="input_last_name", on_change=update_last_name)
|
99 |
+
st.text_input("Email", key="input_email", on_change=update_email)
|
100 |
+
st.text_input("LLM Model Name", key="input_llm_model", on_change=update_llm_model)
|
101 |
+
st.text_input("Organization (Optional)", key="input_organization", on_change=update_organization)
|
102 |
+
|
103 |
+
# Create a checkbox to include the user in the leaderboard
|
104 |
+
st.checkbox("Please check this box if you want your LLM to be included in the leaderboard.",
|
105 |
+
key="input_include_in_leaderboard",
|
106 |
+
on_change=update_include_in_leaderboard)
|
107 |
+
|
108 |
+
if st.button("Evaluate LLM"):
|
109 |
+
# Display a success message
|
110 |
+
st.success("User information saved successfully.")
|
111 |
+
|
112 |
+
# Display an information message
|
113 |
+
st.info(f"""Please wait while we evaluate the factuality of the LLM.
|
114 |
+
You will be able to download the evaluation report shortly, if you can wait. The report will also be delivered to your email address.
|
115 |
+
|
116 |
+
Please note your ID {st.session_state.id}, This will be used to track your evaluation.
|
117 |
+
If the report is not available, please contact the administrator and provide your ID.""")
|
118 |
+
|
119 |
+
# Display a waiting message
|
120 |
+
with st.status("Evaluating factuality of the LLM...", expanded=True) as status:
|
121 |
+
# Evaluate the LLM
|
122 |
+
results = llm_evaluator.evaluate(model_name=st.session_state.llm_model,
|
123 |
+
input_path=uploaded_data,
|
124 |
+
datasets=datasets,
|
125 |
+
save_report=False)
|
126 |
+
|
127 |
+
# Get plots
|
128 |
+
st.write("Generating plots...")
|
129 |
+
plots = llm_evaluator.generate_plots(save_plots=False)
|
130 |
+
|
131 |
+
# Generate the evaluation report
|
132 |
+
st.write("Generating evaluation report...")
|
133 |
+
report_path = llm_evaluator.generate_report(report_path=f"{llm_evaluator.output_path}/{llm_evaluator.run_id}")
|
134 |
+
|
135 |
+
status.update(label="LLM evaluated...", state="complete", expanded=False)
|
136 |
+
|
137 |
+
# Display the plots
|
138 |
+
st.write("### Evaluation Report")
|
139 |
+
|
140 |
+
# If snowballing dataset is selected
|
141 |
+
if "snowballing" in datasets:
|
142 |
+
st.write("#### Evaluation on Snowballing Dataset")
|
143 |
+
col1, col2, col3 = st.columns(3)
|
144 |
+
with col1:
|
145 |
+
st.pyplot(plots["snowballing"]["barplot"])
|
146 |
+
with col2:
|
147 |
+
st.pyplot(plots["snowballing"]["cm"])
|
148 |
+
with col3:
|
149 |
+
pass
|
150 |
+
|
151 |
+
# If selfaware dataset is selected
|
152 |
+
if "selfaware" in datasets:
|
153 |
+
st.write("#### Evaluation on SelfAware Dataset")
|
154 |
+
col1, col2, col3 = st.columns(3)
|
155 |
+
with col1:
|
156 |
+
st.pyplot(plots["selfaware"]["barplot"])
|
157 |
+
with col2:
|
158 |
+
st.pyplot(plots["selfaware"]["cm"])
|
159 |
+
with col3:
|
160 |
+
pass
|
161 |
+
|
162 |
+
# If freshqa dataset is selected
|
163 |
+
if "freshqa" in datasets:
|
164 |
+
st.write("#### Evaluation on FreshQA Dataset")
|
165 |
+
col1, col2, col3 = st.columns(3)
|
166 |
+
with col1:
|
167 |
+
st.pyplot(plots["freshqa"]["piechart"])
|
168 |
+
with col2:
|
169 |
+
pass
|
170 |
+
with col3:
|
171 |
+
pass
|
172 |
+
|
173 |
+
# If any of the free-text datasets are selected
|
174 |
+
if any(dataset in ["factoolqa", "felm-wk", "factcheck-bench", "factscore-bio"] for dataset in datasets):
|
175 |
+
st.write("#### Evaluation on Free-Text Datasets")
|
176 |
+
st.pyplot(plots["freetext"]["barplot"])
|
177 |
+
|
178 |
+
# Generate the evaluation report
|
179 |
+
st.write("### Download Evaluation Report")
|
180 |
+
st.info("The report will also be sent to your email address.")
|
181 |
+
|
182 |
+
# Load the evaluation report
|
183 |
+
if os.path.exists(report_path):
|
184 |
+
with open(report_path, "rb") as file:
|
185 |
+
report_bytes = file.read()
|
186 |
+
|
187 |
+
# Display the download button
|
188 |
+
st.download_button(
|
189 |
+
label="Download",
|
190 |
+
data=report_bytes,
|
191 |
+
file_name="llm_evaluation_report.pdf",
|
192 |
+
mime="application/pdf"
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
st.error("File not found.")
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
src/openfactcheck/app/evaluate_response.py
CHANGED
@@ -22,6 +22,9 @@ def evaluate_response(ofc: OpenFactCheck):
|
|
22 |
This function creates a Streamlit app to evaluate the factuality of a LLM response.
|
23 |
"""
|
24 |
|
|
|
|
|
|
|
25 |
# Initialize the solvers
|
26 |
st.session_state.claimprocessors = ofc.list_claimprocessors()
|
27 |
st.session_state.retrievers = ofc.list_retrievers()
|
@@ -67,7 +70,7 @@ def evaluate_response(ofc: OpenFactCheck):
|
|
67 |
# Evaluate the response
|
68 |
st.write("Evaluating response...")
|
69 |
|
70 |
-
response =
|
71 |
st.write("Response evaluated...")
|
72 |
|
73 |
status.update(label="Factuality checked...", state="complete", expanded=False)
|
@@ -148,23 +151,33 @@ def evaluate_response(ofc: OpenFactCheck):
|
|
148 |
claims=0
|
149 |
false_claims = 0
|
150 |
true_claims = 0
|
|
|
|
|
151 |
for i, detail in enumerate(details):
|
152 |
-
|
|
|
|
|
153 |
claim=detail.get("claim", "")
|
154 |
-
if
|
155 |
detail_text += f'##### :red[{str(i+1) + ". " + extract_text(claim)}]'
|
156 |
detail_text += "\n"
|
157 |
claims += 1
|
158 |
false_claims += 1
|
159 |
-
elif
|
160 |
detail_text += f'##### :green[{str(i+1) + ". " + extract_text(claim)}]'
|
161 |
detail_text += "\n"
|
162 |
claims += 1
|
163 |
true_claims += 1
|
|
|
|
|
|
|
|
|
|
|
164 |
else:
|
165 |
-
detail_text += f'##### :
|
166 |
detail_text += "\n"
|
167 |
claims += 1
|
|
|
168 |
else:
|
169 |
st.error("Factuality not found in the verifier output.")
|
170 |
|
@@ -194,7 +207,7 @@ def evaluate_response(ofc: OpenFactCheck):
|
|
194 |
|
195 |
# Generate formatted text with the overall factuality in Markdown format
|
196 |
formatted_text = "### Factuality Detail\n"
|
197 |
-
formatted_text += "Factuality of each claim is color-coded (red
|
198 |
formatted_text += f"{detail_text}\n"
|
199 |
formatted_text += "\n"
|
200 |
|
@@ -202,6 +215,8 @@ def evaluate_response(ofc: OpenFactCheck):
|
|
202 |
with col2:
|
203 |
metric_card(label="Supported Claims", value=true_claims, background_color="#D1ECF1", border_left_color="#17A2B8")
|
204 |
metric_card(label="Conflicted Claims", value=false_claims, background_color="#D1ECF1", border_left_color="#17A2B8")
|
|
|
|
|
205 |
|
206 |
# Get overall factuality (label)
|
207 |
overall_factuality = output_text.get("label", "Unknown")
|
|
|
22 |
This function creates a Streamlit app to evaluate the factuality of a LLM response.
|
23 |
"""
|
24 |
|
25 |
+
# Initialize the response_evaluator
|
26 |
+
response_evaluator = ofc.ResponseEvaluator
|
27 |
+
|
28 |
# Initialize the solvers
|
29 |
st.session_state.claimprocessors = ofc.list_claimprocessors()
|
30 |
st.session_state.retrievers = ofc.list_retrievers()
|
|
|
70 |
# Evaluate the response
|
71 |
st.write("Evaluating response...")
|
72 |
|
73 |
+
response = response_evaluator.evaluate_streaming(st.session_state.input_text)
|
74 |
st.write("Response evaluated...")
|
75 |
|
76 |
status.update(label="Factuality checked...", state="complete", expanded=False)
|
|
|
151 |
claims=0
|
152 |
false_claims = 0
|
153 |
true_claims = 0
|
154 |
+
controversial_claims = 0
|
155 |
+
unverified_claims = 0
|
156 |
for i, detail in enumerate(details):
|
157 |
+
# Get factuality information
|
158 |
+
factuality = str(detail.get("factuality", None))
|
159 |
+
if factuality is not None:
|
160 |
claim=detail.get("claim", "")
|
161 |
+
if factuality == "-1" or factuality == "False":
|
162 |
detail_text += f'##### :red[{str(i+1) + ". " + extract_text(claim)}]'
|
163 |
detail_text += "\n"
|
164 |
claims += 1
|
165 |
false_claims += 1
|
166 |
+
elif factuality == "1" or factuality == "True":
|
167 |
detail_text += f'##### :green[{str(i+1) + ". " + extract_text(claim)}]'
|
168 |
detail_text += "\n"
|
169 |
claims += 1
|
170 |
true_claims += 1
|
171 |
+
elif factuality == "0":
|
172 |
+
detail_text += f'##### :orange[{str(i+1) + ". " + extract_text(claim)}]'
|
173 |
+
detail_text += "\n"
|
174 |
+
claims += 1
|
175 |
+
controversial_claims += 1
|
176 |
else:
|
177 |
+
detail_text += f'##### :purple[{str(i+1) + ". " + extract_text(claim)}]'
|
178 |
detail_text += "\n"
|
179 |
claims += 1
|
180 |
+
unverified_claims += 1
|
181 |
else:
|
182 |
st.error("Factuality not found in the verifier output.")
|
183 |
|
|
|
207 |
|
208 |
# Generate formatted text with the overall factuality in Markdown format
|
209 |
formatted_text = "### Factuality Detail\n"
|
210 |
+
formatted_text += "Factuality of each claim is color-coded (:red[red means false], :green[green means true], :orange[orange means controversial], :violet[violet means unverified]).\n"
|
211 |
formatted_text += f"{detail_text}\n"
|
212 |
formatted_text += "\n"
|
213 |
|
|
|
215 |
with col2:
|
216 |
metric_card(label="Supported Claims", value=true_claims, background_color="#D1ECF1", border_left_color="#17A2B8")
|
217 |
metric_card(label="Conflicted Claims", value=false_claims, background_color="#D1ECF1", border_left_color="#17A2B8")
|
218 |
+
metric_card(label="Controversial Claims", value=controversial_claims, background_color="#D1ECF1", border_left_color="#17A2B8")
|
219 |
+
metric_card(label="Unverified Claims", value=unverified_claims, background_color="#D1ECF1", border_left_color="#17A2B8")
|
220 |
|
221 |
# Get overall factuality (label)
|
222 |
overall_factuality = output_text.get("label", "Unknown")
|
src/openfactcheck/core/base.py
CHANGED
@@ -1,16 +1,18 @@
|
|
1 |
import os
|
2 |
import sys
|
3 |
-
import uuid
|
4 |
-
import tqdm
|
5 |
import json
|
6 |
-
import traceback
|
7 |
from pathlib import Path
|
8 |
-
from typing import
|
9 |
|
10 |
from openfactcheck.lib.logger import logger
|
|
|
11 |
from openfactcheck.lib.config import OpenFactCheckConfig
|
12 |
from openfactcheck.core.solver import SOLVER_REGISTRY, Solver
|
13 |
-
|
|
|
|
|
|
|
|
|
14 |
|
15 |
class OpenFactCheck:
|
16 |
"""
|
@@ -96,24 +98,50 @@ class OpenFactCheck:
|
|
96 |
self.logger.info(f"Loaded solvers: {list(self.list_solvers().keys())}")
|
97 |
|
98 |
# Initialize the pipeline
|
99 |
-
self.
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
@staticmethod
|
108 |
-
def load_solvers(solver_paths):
|
109 |
"""
|
110 |
Load solvers from the given paths
|
111 |
"""
|
112 |
-
for
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
@staticmethod
|
119 |
def list_solvers():
|
@@ -173,9 +201,9 @@ class OpenFactCheck:
|
|
173 |
|
174 |
# Initialize the solver
|
175 |
solver_cls = SOLVER_REGISTRY[solver_name]
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
logger.info(f"Solver {solver_cls(args)} initialized")
|
180 |
|
181 |
return solver_cls(args), solver_cls.input_name, solver_cls.output_name
|
@@ -194,15 +222,19 @@ class OpenFactCheck:
|
|
194 |
"""
|
195 |
Initialize the pipeline with the given configuration
|
196 |
"""
|
197 |
-
pipeline = {}
|
198 |
for required_solver in self.config.pipeline:
|
199 |
if required_solver not in self.solver_configs:
|
200 |
logger.error(f"{required_solver} not in solvers config")
|
201 |
raise RuntimeError(f"{required_solver} not in solvers config")
|
202 |
solver, input_name, output_name = self.init_solver(required_solver, self.solver_configs[required_solver])
|
203 |
-
pipeline[required_solver] = (solver, input_name, output_name)
|
204 |
|
205 |
-
|
|
|
|
|
|
|
|
|
206 |
|
207 |
def init_pipeline_manually(self, pipeline: list):
|
208 |
"""
|
@@ -220,95 +252,8 @@ class OpenFactCheck:
|
|
220 |
solver, input_name, output_name = self.init_solver(required_solver, self.solver_configs[required_solver])
|
221 |
self.pipeline[required_solver] = (solver, input_name, output_name)
|
222 |
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
"
|
227 |
-
|
228 |
-
"state": state.to_dict()
|
229 |
-
}
|
230 |
-
with open(os.path.join(self.output_path, f'{sample_name}.jsonl'), 'a', encoding="utf-8") as f:
|
231 |
-
f.write(json.dumps(result, ensure_ascii=False) + '\n')
|
232 |
-
|
233 |
-
def read_output(self, sample_name):
|
234 |
-
"""
|
235 |
-
Read the output file for the given sample
|
236 |
-
"""
|
237 |
-
with open(os.path.join(self.output_path, f'{sample_name}.jsonl'), 'r', encoding="utf-8") as f:
|
238 |
-
return [json.loads(line) for line in f]
|
239 |
-
|
240 |
-
def remove_output(self, sample_name):
|
241 |
-
"""
|
242 |
-
Remove the output file for the given sample
|
243 |
-
"""
|
244 |
-
os.remove(os.path.join(self.output_path, f'{sample_name}.jsonl'))
|
245 |
-
|
246 |
-
def __call__(self, response: str, question: str = None, stream: bool = False, callback: Callable = None, **kwargs):
|
247 |
-
"""
|
248 |
-
Evaluate the response using the pipeline
|
249 |
-
"""
|
250 |
-
|
251 |
-
def evaluate_response():
|
252 |
-
# Check if sample_name is provided in kwargs else generate a random one
|
253 |
-
sample_name = kwargs.get("sample_name", str(uuid.uuid4().hex[:6]))
|
254 |
-
|
255 |
-
# Initialize the state
|
256 |
-
solver_output = FactCheckerState(question=question, response=response)
|
257 |
-
|
258 |
-
# Initialize the output name
|
259 |
-
output_name = "response"
|
260 |
-
for idx, (name, (solver, input_name, output_name)) in tqdm.tqdm(enumerate(self.pipeline.items()),
|
261 |
-
total=len(self.pipeline)):
|
262 |
-
logger.info(f"Invoking solver: {idx}-{name}")
|
263 |
-
logger.info(f"State content: {solver_output}")
|
264 |
-
|
265 |
-
try:
|
266 |
-
# Solver input is the output of the previous solver
|
267 |
-
solver_input = solver_output
|
268 |
-
|
269 |
-
# Run the solver
|
270 |
-
cont, solver_output = solver(solver_input, **kwargs)
|
271 |
-
|
272 |
-
# Persist the output
|
273 |
-
logger.debug(f"Latest result: {solver_output}")
|
274 |
-
if callback:
|
275 |
-
callback(
|
276 |
-
index=idx,
|
277 |
-
sample_name=sample_name,
|
278 |
-
solver_name=name,
|
279 |
-
input_name=input_name,
|
280 |
-
output_name=output_name,
|
281 |
-
input=solver_input.__dict__,
|
282 |
-
output=solver_output.__dict__,
|
283 |
-
continue_run=cont
|
284 |
-
)
|
285 |
-
|
286 |
-
# Stream the output
|
287 |
-
if stream:
|
288 |
-
yield {
|
289 |
-
"index": idx,
|
290 |
-
"solver_name": name,
|
291 |
-
"input_name": input_name,
|
292 |
-
"output_name": output_name,
|
293 |
-
"input": solver_input.__dict__,
|
294 |
-
"output": solver_output.__dict__,
|
295 |
-
"continue_run": cont
|
296 |
-
}
|
297 |
-
|
298 |
-
self.persist_output(solver_output, idx, name, cont, sample_name=sample_name)
|
299 |
-
|
300 |
-
except:
|
301 |
-
logger.error(f"Error at {traceback.format_exc()}")
|
302 |
-
cont = False
|
303 |
-
output_name = input_name
|
304 |
-
|
305 |
-
# Break if the solver returns False
|
306 |
-
if not cont:
|
307 |
-
logger.info(f"Break at {name}")
|
308 |
-
break
|
309 |
-
|
310 |
-
if not stream:
|
311 |
-
return solver_output.get(output_name)
|
312 |
-
|
313 |
-
# Execute the generator if stream is True, otherwise process normally
|
314 |
-
return evaluate_response()
|
|
|
1 |
import os
|
2 |
import sys
|
|
|
|
|
3 |
import json
|
|
|
4 |
from pathlib import Path
|
5 |
+
from typing import TYPE_CHECKING
|
6 |
|
7 |
from openfactcheck.lib.logger import logger
|
8 |
+
from openfactcheck.core.state import FactCheckerState
|
9 |
from openfactcheck.lib.config import OpenFactCheckConfig
|
10 |
from openfactcheck.core.solver import SOLVER_REGISTRY, Solver
|
11 |
+
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from openfactcheck.evaluator.llm.evaluate import LLMEvaluator
|
14 |
+
from openfactcheck.evaluator.response.evaluate import ResponseEvaluator
|
15 |
+
from openfactcheck.evaluator.factchecker.evaluate import FactCheckerEvaluator
|
16 |
|
17 |
class OpenFactCheck:
|
18 |
"""
|
|
|
98 |
self.logger.info(f"Loaded solvers: {list(self.list_solvers().keys())}")
|
99 |
|
100 |
# Initialize the pipeline
|
101 |
+
self.init_pipeline()
|
102 |
+
|
103 |
+
@property
|
104 |
+
def LLMEvaluator(self) -> 'LLMEvaluator':
|
105 |
+
"""
|
106 |
+
Return the LLM Evaluator
|
107 |
+
"""
|
108 |
+
from openfactcheck.evaluator.llm.evaluate import LLMEvaluator
|
109 |
+
return LLMEvaluator(self)
|
110 |
+
|
111 |
+
@property
|
112 |
+
def FactCheckerEvaluator(self) -> 'FactCheckerEvaluator':
|
113 |
+
"""
|
114 |
+
Return the FactChecker Evaluator
|
115 |
+
"""
|
116 |
+
from openfactcheck.evaluator.factchecker.evaluate import FactCheckerEvaluator
|
117 |
+
return FactCheckerEvaluator(self)
|
118 |
+
|
119 |
+
@property
|
120 |
+
def ResponseEvaluator(self) -> 'ResponseEvaluator':
|
121 |
+
"""
|
122 |
+
Return the LLM Response Evaluator
|
123 |
+
"""
|
124 |
+
from openfactcheck.evaluator.response.evaluate import ResponseEvaluator
|
125 |
+
return ResponseEvaluator(self)
|
126 |
|
127 |
@staticmethod
|
128 |
+
def load_solvers(solver_paths: dict):
|
129 |
"""
|
130 |
Load solvers from the given paths
|
131 |
"""
|
132 |
+
for key, value in solver_paths.items():
|
133 |
+
if key == "default":
|
134 |
+
for solver_path in value:
|
135 |
+
abs_path = Path(solver_path).resolve()
|
136 |
+
if abs_path.is_dir():
|
137 |
+
sys.path.append(str(abs_path.parent))
|
138 |
+
Solver.load(str(abs_path), f"{abs_path.parent.parent.name}.{abs_path.parent.name}.{abs_path.name}")
|
139 |
+
else:
|
140 |
+
for solver_path in value:
|
141 |
+
abs_path = Path(solver_path).resolve()
|
142 |
+
if abs_path.is_dir():
|
143 |
+
sys.path.append(str(abs_path.parent))
|
144 |
+
Solver.load(str(abs_path), abs_path.name)
|
145 |
|
146 |
@staticmethod
|
147 |
def list_solvers():
|
|
|
201 |
|
202 |
# Initialize the solver
|
203 |
solver_cls = SOLVER_REGISTRY[solver_name]
|
204 |
+
for key, value in args.items():
|
205 |
+
setattr(solver_cls, key, value)
|
206 |
+
|
207 |
logger.info(f"Solver {solver_cls(args)} initialized")
|
208 |
|
209 |
return solver_cls(args), solver_cls.input_name, solver_cls.output_name
|
|
|
222 |
"""
|
223 |
Initialize the pipeline with the given configuration
|
224 |
"""
|
225 |
+
self.pipeline = {}
|
226 |
for required_solver in self.config.pipeline:
|
227 |
if required_solver not in self.solver_configs:
|
228 |
logger.error(f"{required_solver} not in solvers config")
|
229 |
raise RuntimeError(f"{required_solver} not in solvers config")
|
230 |
solver, input_name, output_name = self.init_solver(required_solver, self.solver_configs[required_solver])
|
231 |
+
self.pipeline[required_solver] = (solver, input_name, output_name)
|
232 |
|
233 |
+
self.logger.info("-------------- OpenFactCheck Initialized ----------------")
|
234 |
+
self.logger.info("Pipeline:")
|
235 |
+
for idx, (name, (solver, iname, oname)) in enumerate(self.pipeline.items()):
|
236 |
+
self.logger.info(f"{idx}-{name} ({iname} -> {oname})")
|
237 |
+
self.logger.info("---------------------------------------------------------")
|
238 |
|
239 |
def init_pipeline_manually(self, pipeline: list):
|
240 |
"""
|
|
|
252 |
solver, input_name, output_name = self.init_solver(required_solver, self.solver_configs[required_solver])
|
253 |
self.pipeline[required_solver] = (solver, input_name, output_name)
|
254 |
|
255 |
+
self.logger.info("-------------- OpenFactCheck Initialized ----------------")
|
256 |
+
self.logger.info("Pipeline:")
|
257 |
+
for idx, (name, (solver, iname, oname)) in enumerate(self.pipeline.items()):
|
258 |
+
self.logger.info(f"{idx}-{name} ({iname} -> {oname})")
|
259 |
+
self.logger.info("---------------------------------------------------------")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/openfactcheck/core/cli.py
CHANGED
@@ -19,4 +19,11 @@ def parse_args():
|
|
19 |
if __name__ == "__main__":
|
20 |
args = parse_args()
|
21 |
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
if __name__ == "__main__":
|
20 |
args = parse_args()
|
21 |
|
22 |
+
def callback(index, sample_name, solver_name, input_name, output_name, input, output, continue_run):
|
23 |
+
print(f"Callback: {index}, {sample_name}, {solver_name}, {input_name}, {output_name}, {input}, {output}, {continue_run}")
|
24 |
+
|
25 |
+
config = OpenFactCheckConfig(args.config_path)
|
26 |
+
results = OpenFactCheck(config).LLMEvaluator.evaluate(model_name="gpt2",
|
27 |
+
input_path="src/openfactcheck/templates/llm/responses_test.csv")
|
28 |
+
|
29 |
+
|
src/openfactcheck/core/solver.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
import importlib
|
3 |
|
@@ -121,9 +122,18 @@ class Solver:
|
|
121 |
# Get the module name
|
122 |
module_name = namespace + "." + solver_name
|
123 |
|
|
|
|
|
|
|
124 |
# Import the module
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
@staticmethod
|
129 |
def load(path, namespace):
|
|
|
1 |
+
import sys
|
2 |
import os
|
3 |
import importlib
|
4 |
|
|
|
122 |
# Get the module name
|
123 |
module_name = namespace + "." + solver_name
|
124 |
|
125 |
+
# Log the full module name to debug
|
126 |
+
logger.debug(f"Attempting to import {module_name} from {file_path}")
|
127 |
+
|
128 |
# Import the module
|
129 |
+
try:
|
130 |
+
importlib.import_module(module_name)
|
131 |
+
logger.debug(f"Successfully imported {module_name}")
|
132 |
+
except Exception as e:
|
133 |
+
logger.error(f"Failed to import {module_name}: {e}")
|
134 |
+
raise Exception(f"Failed to import {module_name}: {e}")
|
135 |
+
|
136 |
+
return module_name
|
137 |
|
138 |
@staticmethod
|
139 |
def load(path, namespace):
|
src/openfactcheck/{factchecker → data}/__init__.py
RENAMED
File without changes
|
src/openfactcheck/data/factcheck-bench.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/openfactcheck/data/factoolqa.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/openfactcheck/data/factscore-bio.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/openfactcheck/data/felm-wk.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/openfactcheck/data/freshqa.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/openfactcheck/data/selfaware.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/openfactcheck/data/snowballing.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/openfactcheck/{solvers/factool/ftool_utils → evaluator}/__init__.py
RENAMED
File without changes
|
src/openfactcheck/evaluator/factchecker/__init__.py
ADDED
File without changes
|
src/openfactcheck/{factchecker → evaluator/factchecker}/evaluate.py
RENAMED
@@ -1,18 +1,18 @@
|
|
|
|
1 |
import pandas as pd
|
2 |
from typing import Union
|
|
|
3 |
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
|
4 |
from sklearn.metrics import classification_report, confusion_matrix
|
5 |
|
6 |
from openfactcheck.core.base import OpenFactCheck
|
7 |
-
from
|
8 |
-
|
9 |
-
from openfactcheck.templates import factchecker as gold_templates_dir
|
10 |
|
11 |
# Import solver configuration templates
|
12 |
-
gold_claims_template_path = str(pkg_resources.files(
|
13 |
-
gold_documents_template_path = str(pkg_resources.files(
|
14 |
|
15 |
-
class FactCheckerEvaluator(
|
16 |
"""
|
17 |
This class is used to evaluate the performance of a FactChecker.
|
18 |
|
@@ -33,7 +33,7 @@ class FactCheckerEvaluator(OpenFactCheck):
|
|
33 |
|
34 |
Attributes
|
35 |
----------
|
36 |
-
|
37 |
The path to the CSV file or the DataFrame containing the FactChecker responses.
|
38 |
gold_path : str
|
39 |
The path to the gold standard file.
|
@@ -48,25 +48,20 @@ class FactCheckerEvaluator(OpenFactCheck):
|
|
48 |
|
49 |
Methods
|
50 |
-------
|
51 |
-
|
52 |
This function evaluates the performance of the FactChecker.
|
53 |
evaluate_binary_classification(y_true, y_pred, pos_label="yes"):
|
54 |
-
|
55 |
"""
|
56 |
-
def __init__(self,
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
elif eval_type == "documents":
|
61 |
-
gold_path = gold_documents_template_path
|
62 |
-
else:
|
63 |
-
raise ValueError("Invalid evaluation type. Please provide a valid evaluation type.")
|
64 |
|
65 |
-
|
66 |
-
self.
|
67 |
-
self.
|
68 |
-
|
69 |
-
# Results
|
70 |
self.results = None
|
71 |
self.confusion_matrix = None
|
72 |
self.classification_report = None
|
@@ -86,19 +81,34 @@ class FactCheckerEvaluator(OpenFactCheck):
|
|
86 |
}
|
87 |
return metrics
|
88 |
|
89 |
-
def
|
90 |
"""
|
91 |
This function evaluates the performance of the FactChecker.
|
92 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
# Load the gold standard file
|
94 |
-
|
|
|
|
|
95 |
|
96 |
-
# Check if the
|
97 |
-
if isinstance(self.
|
98 |
-
df_input = self.
|
99 |
else:
|
100 |
# Read the CSV file
|
101 |
-
df_input = pd.read_csv(self.
|
102 |
|
103 |
# Check if the FactChecker responses have the correct number of columns
|
104 |
assert len(df_input.columns) == 3
|
|
|
1 |
+
import io
|
2 |
import pandas as pd
|
3 |
from typing import Union
|
4 |
+
from importlib import resources as pkg_resources
|
5 |
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
|
6 |
from sklearn.metrics import classification_report, confusion_matrix
|
7 |
|
8 |
from openfactcheck.core.base import OpenFactCheck
|
9 |
+
from openfactcheck.templates import factchecker as templates_dir
|
|
|
|
|
10 |
|
11 |
# Import solver configuration templates
|
12 |
+
gold_claims_template_path = str(pkg_resources.files(templates_dir) / "gold/claims.jsonl")
|
13 |
+
gold_documents_template_path = str(pkg_resources.files(templates_dir) / "gold/documents.jsonl")
|
14 |
|
15 |
+
class FactCheckerEvaluator():
|
16 |
"""
|
17 |
This class is used to evaluate the performance of a FactChecker.
|
18 |
|
|
|
33 |
|
34 |
Attributes
|
35 |
----------
|
36 |
+
input_path : Union[str, pd.DataFrame]
|
37 |
The path to the CSV file or the DataFrame containing the FactChecker responses.
|
38 |
gold_path : str
|
39 |
The path to the gold standard file.
|
|
|
48 |
|
49 |
Methods
|
50 |
-------
|
51 |
+
evaluate(input_path: Union[str, pd.DataFrame], eval_type: str, gold_path: str = ""):
|
52 |
This function evaluates the performance of the FactChecker.
|
53 |
evaluate_binary_classification(y_true, y_pred, pos_label="yes"):
|
54 |
+
Evaluate the performance of a binary classification task.
|
55 |
"""
|
56 |
+
def __init__(self, ofc: OpenFactCheck):
|
57 |
+
"""
|
58 |
+
Initialize the FactCheckerEvaluator object.
|
59 |
+
"""
|
|
|
|
|
|
|
|
|
60 |
|
61 |
+
# Set the attributes
|
62 |
+
self.input_path = None
|
63 |
+
self.gold_path = None
|
64 |
+
self.eval_type = None
|
|
|
65 |
self.results = None
|
66 |
self.confusion_matrix = None
|
67 |
self.classification_report = None
|
|
|
81 |
}
|
82 |
return metrics
|
83 |
|
84 |
+
def evaluate(self, input_path: Union[str, pd.DataFrame], eval_type: str, gold_path: str = ""):
|
85 |
"""
|
86 |
This function evaluates the performance of the FactChecker.
|
87 |
"""
|
88 |
+
# Set the input_path, gold_path, and eval_type attributes
|
89 |
+
self.input_path = input_path
|
90 |
+
self.gold_path = gold_path
|
91 |
+
self.eval_type = eval_type
|
92 |
+
|
93 |
+
if self.gold_path == "":
|
94 |
+
if eval_type == "claims":
|
95 |
+
self.gold_path = gold_claims_template_path
|
96 |
+
elif eval_type == "documents":
|
97 |
+
self.gold_path = gold_documents_template_path
|
98 |
+
else:
|
99 |
+
raise ValueError("Invalid evaluation type. Please provide a valid evaluation type.")
|
100 |
+
|
101 |
# Load the gold standard file
|
102 |
+
with open(self.gold_path, "r") as f:
|
103 |
+
json_data = f.read()
|
104 |
+
df_gold = pd.read_json(io.StringIO(json_data), lines=True)
|
105 |
|
106 |
+
# Check if the input_path is a DataFrame
|
107 |
+
if isinstance(self.input_path, pd.DataFrame):
|
108 |
+
df_input = self.input_path
|
109 |
else:
|
110 |
# Read the CSV file
|
111 |
+
df_input = pd.read_csv(self.input_path)
|
112 |
|
113 |
# Check if the FactChecker responses have the correct number of columns
|
114 |
assert len(df_input.columns) == 3
|
src/openfactcheck/evaluator/llm/__init__.py
ADDED
File without changes
|
src/openfactcheck/evaluator/llm/evaluate.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import pandas as pd
|
4 |
+
from importlib import resources as pkg_resources
|
5 |
+
|
6 |
+
from openfactcheck.lib.logger import logger
|
7 |
+
from openfactcheck.core.base import OpenFactCheck
|
8 |
+
from openfactcheck.evaluator.llm.evaluate_snowballing import SnowballingEvaluator
|
9 |
+
from openfactcheck.evaluator.llm.evaluate_selfaware import SelfAwareEvaluator
|
10 |
+
from openfactcheck.evaluator.llm.evaluate_freshqa import FreshQAEvaluator
|
11 |
+
from openfactcheck.evaluator.llm.evaluate_freetext import FreeTextEvaluator
|
12 |
+
from openfactcheck.evaluator.llm.report import create_report
|
13 |
+
|
14 |
+
from openfactcheck import data as data_dir
|
15 |
+
|
16 |
+
# Import LLM Evaluation Dataset
|
17 |
+
default_dataset_path = str(pkg_resources.files(data_dir))
|
18 |
+
default_output_path = "tmp/output/llm_evaluator"
|
19 |
+
|
20 |
+
class LLMEvaluator(SnowballingEvaluator, SelfAwareEvaluator, FreshQAEvaluator, FreeTextEvaluator):
|
21 |
+
"""
|
22 |
+
This class is used to evaluate the performance of a Language Model.
|
23 |
+
|
24 |
+
Parameters
|
25 |
+
----------
|
26 |
+
model_name : str
|
27 |
+
The name of the Language Model.
|
28 |
+
input_path : Union[str, pd.DataFrame]
|
29 |
+
The path to the CSV file or the DataFrame containing the LLM responses.
|
30 |
+
The CSV file should have the following two columns:
|
31 |
+
- index: The index of the response.
|
32 |
+
- response: The response generated by the LLM.
|
33 |
+
output_path : str
|
34 |
+
The path to store the output files.
|
35 |
+
dataset_path : str
|
36 |
+
The path to the dataset file containing the questions.
|
37 |
+
datasets : list
|
38 |
+
The list of datasets to evaluate the LLM on.
|
39 |
+
analyze : bool
|
40 |
+
Whether to analyze the results.
|
41 |
+
save_plots : bool
|
42 |
+
Whether to save the plots.
|
43 |
+
save_report : bool
|
44 |
+
Whether to save the report.
|
45 |
+
|
46 |
+
Attributes
|
47 |
+
----------
|
48 |
+
model_name : str
|
49 |
+
The name of the Language Model.
|
50 |
+
run_id : str
|
51 |
+
The unique identifier for the run.
|
52 |
+
input_path : Union[str, pd.DataFrame]
|
53 |
+
The path to the CSV file or the DataFrame containing the LLM responses.
|
54 |
+
output_path : str
|
55 |
+
The path to store the output files.
|
56 |
+
dataset_path : str
|
57 |
+
The path to the dataset file containing the questions.
|
58 |
+
datasets : list
|
59 |
+
The list of datasets to evaluate the LLM on.
|
60 |
+
combined_result : dict
|
61 |
+
The combined evaluation results for all datasets.
|
62 |
+
|
63 |
+
Methods
|
64 |
+
-------
|
65 |
+
evaluate(model_name: str, input_path: Union[str, pd.DataFrame], output_path: str = "", dataset_path: str = "", datasets: list = ["snowballing"], analyze: bool = True, save_plots: bool = True, save_report: bool = True):
|
66 |
+
This function evaluates the performance of the Language Model.
|
67 |
+
read_input():
|
68 |
+
This function reads the input file and dataset file and returns a DataFrame containing the combined data.
|
69 |
+
filter_responses(df: pd.DataFrame, dataset: str):
|
70 |
+
Filter the responses based on the dataset.
|
71 |
+
generate_plots(fig_path, save_plots=True):
|
72 |
+
Generate plots for the evaluation
|
73 |
+
"""
|
74 |
+
def __init__(self, ofc: OpenFactCheck):
|
75 |
+
SnowballingEvaluator.__init__(self)
|
76 |
+
SelfAwareEvaluator.__init__(self)
|
77 |
+
FreshQAEvaluator.__init__(self)
|
78 |
+
FreeTextEvaluator.__init__(self, ofc)
|
79 |
+
self.logger = logger
|
80 |
+
|
81 |
+
# Set the attributes
|
82 |
+
self.model_name = None
|
83 |
+
self.run_id = None
|
84 |
+
self.input_path = None
|
85 |
+
self.dataset_path = None
|
86 |
+
self.output_path = None
|
87 |
+
self.datasets = None
|
88 |
+
|
89 |
+
self.combined_result = None
|
90 |
+
self.labels = None
|
91 |
+
self.predictions = None
|
92 |
+
|
93 |
+
self.logger.info(f"LLM Evaluator initialized with run_id: {self.run_id}")
|
94 |
+
|
95 |
+
def read_input(self):
|
96 |
+
"""
|
97 |
+
This function reads the input file and dataset file and returns a DataFrame containing the combined data.
|
98 |
+
"""
|
99 |
+
|
100 |
+
# Check if the input_path is a DataFrame
|
101 |
+
if isinstance(self.input_path, pd.DataFrame):
|
102 |
+
df_responses = self.input_path
|
103 |
+
else:
|
104 |
+
# Read the CSV file
|
105 |
+
self.logger.info(f"Reading the LLM responses from {self.input_path}...")
|
106 |
+
df_responses = pd.read_csv(self.input_path)
|
107 |
+
|
108 |
+
# Check the number of columns and if any response is missing
|
109 |
+
assert df_responses.shape[1] == 2, "The LLM responses should have 2 columns."
|
110 |
+
|
111 |
+
# Use the first column as index and rename the index and response column
|
112 |
+
df_responses.set_index(df_responses.columns[0], inplace=True)
|
113 |
+
df_responses.index.name = None
|
114 |
+
df_responses.columns = ["response"]
|
115 |
+
|
116 |
+
# Read the avaliable datasets
|
117 |
+
self.logger.info(f"Reading the dataset from {self.dataset_path}...")
|
118 |
+
df_dataset = pd.DataFrame()
|
119 |
+
# Loop through each file in the directory
|
120 |
+
for filename in os.listdir(self.dataset_path):
|
121 |
+
if filename.endswith('.jsonl'): # Check if the file is a JSONL file
|
122 |
+
file_path = os.path.join(self.dataset_path, filename)
|
123 |
+
|
124 |
+
logger.info(f"Reading {filename.split('.')[0]} dataset...")
|
125 |
+
# Read the JSONL file and append it to the combined DataFrame
|
126 |
+
df = pd.read_json(file_path, lines=True)
|
127 |
+
df_dataset = pd.concat([df_dataset, df], ignore_index=True)
|
128 |
+
|
129 |
+
# Combine the responses and questions
|
130 |
+
assert len(df_responses) == len(df_dataset), "The number of responses and questions should be the same."
|
131 |
+
df_combined = pd.concat([df_dataset, df_responses], axis=1)
|
132 |
+
|
133 |
+
return df_combined
|
134 |
+
|
135 |
+
@staticmethod
|
136 |
+
def filter_responses(df: pd.DataFrame, dataset: str):
|
137 |
+
logger.info(f"Filtering responses for dataset: {dataset}...")
|
138 |
+
# Filter the DataFrame based on the 'source' column directly
|
139 |
+
if dataset == "snowballing":
|
140 |
+
# Filter the DataFrame based on the 'source' column directly
|
141 |
+
filtered_df = df[df['source'] == dataset]
|
142 |
+
|
143 |
+
# Create a new DataFrame with only the required columns
|
144 |
+
responses_df = filtered_df[['topic', 'response']].copy()
|
145 |
+
|
146 |
+
# Convert the DataFrame to a list of dictionaries if needed
|
147 |
+
responses = responses_df.to_dict(orient='records')
|
148 |
+
|
149 |
+
elif dataset == "selfaware":
|
150 |
+
selfaware_subset = [
|
151 |
+
"selfaware-hotpot_train",
|
152 |
+
"selfaware-squadqa_train",
|
153 |
+
"selfaware-triviaqa_train",
|
154 |
+
"selfaware-squadqa_dev",
|
155 |
+
"selfaware-hotpot_dev",
|
156 |
+
"selfaware-triviaqa_dev",
|
157 |
+
"selfaware-SelfAware",
|
158 |
+
]
|
159 |
+
|
160 |
+
responses = []
|
161 |
+
for k, row in df.iterrows():
|
162 |
+
if row["source"] in selfaware_subset:
|
163 |
+
responses.append({
|
164 |
+
"label_unanswerable": row["ability_to_test"].lstrip("answerable: ") == "False",
|
165 |
+
"response": row["response"]})
|
166 |
+
|
167 |
+
elif dataset == "freshqa":
|
168 |
+
responses = []
|
169 |
+
for k, row in df.iterrows():
|
170 |
+
if row["source"] == dataset:
|
171 |
+
responses.append(
|
172 |
+
{
|
173 |
+
"question": row["question"],
|
174 |
+
"reference_answer": row["reference_answer"],
|
175 |
+
"response": row["response"],
|
176 |
+
}
|
177 |
+
)
|
178 |
+
|
179 |
+
elif dataset in ["factoolqa", "felm-wk", "factcheck-bench"]:
|
180 |
+
responses = []
|
181 |
+
for k, row in df.iterrows():
|
182 |
+
if row["source"] == dataset:
|
183 |
+
responses.append(
|
184 |
+
{
|
185 |
+
"source": row["source"],
|
186 |
+
"prompt": row["prompt"],
|
187 |
+
"response": row["response"],
|
188 |
+
}
|
189 |
+
)
|
190 |
+
|
191 |
+
elif dataset == "factscore-bio":
|
192 |
+
factscore_subset = [
|
193 |
+
"factscore-labelled",
|
194 |
+
"factscore-unlabelled",
|
195 |
+
]
|
196 |
+
|
197 |
+
responses = []
|
198 |
+
for k, row in df.iterrows():
|
199 |
+
if row["source"] in factscore_subset:
|
200 |
+
responses.append(
|
201 |
+
{
|
202 |
+
"source": row["source"],
|
203 |
+
"prompt": row["prompt"],
|
204 |
+
"response": row["response"],
|
205 |
+
}
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
raise ValueError(f"Dataset {dataset} is not supported.")
|
209 |
+
|
210 |
+
return responses
|
211 |
+
|
212 |
+
def generate_plots(self, fig_path: str = "", save_plots=True):
|
213 |
+
# Create a bar plot of the accuracy of the LLM responses on the Snowballing dataset
|
214 |
+
# for each topic and the overall accuracy.
|
215 |
+
plots = {}
|
216 |
+
for dataset in self.combined_result:
|
217 |
+
if dataset == "snowballing":
|
218 |
+
plots["snowballing"] = {}
|
219 |
+
plots["snowballing"]["barplot"] = self.snowballing_barplot(self.combined_result[dataset], fig_path, save=save_plots)
|
220 |
+
plots["snowballing"]["cm"] = self.snowballing_cm(self.labels[dataset], self.predictions[dataset], fig_path, save=save_plots)
|
221 |
+
|
222 |
+
elif dataset == "selfaware":
|
223 |
+
plots["selfaware"] = {}
|
224 |
+
plots["selfaware"]["barplot"] = self.selfaware_barplot(self.combined_result[dataset], fig_path, save=save_plots)
|
225 |
+
plots["selfaware"]["cm"] = self.selfaware_cm(self.labels[dataset], self.predictions[dataset], fig_path, save=save_plots)
|
226 |
+
|
227 |
+
elif dataset == "freshqa":
|
228 |
+
plots["freshqa"] = {}
|
229 |
+
plots["freshqa"]["piechart"] = self.freshqa_piechart(self.combined_result[dataset], fig_path, save=save_plots)
|
230 |
+
|
231 |
+
elif dataset == "freetext":
|
232 |
+
plots["freetext"] = {}
|
233 |
+
plots["freetext"]["barplot"] = self.freetext_barplot(self.combined_result["freetext"], fig_path, save=save_plots)
|
234 |
+
|
235 |
+
return plots
|
236 |
+
|
237 |
+
def generate_report(self, report_path: str):
|
238 |
+
# Create a LaTeX report and return the path to the generated PDF
|
239 |
+
return create_report(self.model_name, report_path)
|
240 |
+
|
241 |
+
def evaluate(self,
|
242 |
+
model_name: str,
|
243 |
+
input_path: str,
|
244 |
+
output_path: str = "",
|
245 |
+
dataset_path: str = "",
|
246 |
+
datasets: list = [
|
247 |
+
"snowballing",
|
248 |
+
"selfaware",
|
249 |
+
"freshqa",
|
250 |
+
"factoolqa",
|
251 |
+
"felm-wk",
|
252 |
+
"factcheck-bench",
|
253 |
+
"factscore-bio"
|
254 |
+
],
|
255 |
+
analyze: bool = True,
|
256 |
+
save_report: bool = True):
|
257 |
+
self.logger.info("Evaluating LLM responses...")
|
258 |
+
|
259 |
+
# Set the attributes
|
260 |
+
self.model_name = model_name
|
261 |
+
self.run_id = "123"
|
262 |
+
self.input_path = input_path
|
263 |
+
self.output_path = output_path
|
264 |
+
self.dataset_path = dataset_path
|
265 |
+
self.datasets = datasets
|
266 |
+
|
267 |
+
# Check if the output path is provided (if not, use the default template)
|
268 |
+
if self.output_path == "":
|
269 |
+
self.output_path = default_output_path
|
270 |
+
|
271 |
+
# Check if the output path exists (if not, create it)
|
272 |
+
if not os.path.exists(self.output_path):
|
273 |
+
os.makedirs(f"{self.output_path}/{self.run_id}")
|
274 |
+
|
275 |
+
# Check if the questions path is provided (if not, use the default template)
|
276 |
+
if self.dataset_path == "":
|
277 |
+
self.dataset_path = default_dataset_path
|
278 |
+
|
279 |
+
# Read the input
|
280 |
+
self.logger.info("Reading the input...")
|
281 |
+
df = self.read_input()
|
282 |
+
self.logger.info(f"Combined data contains {len(df)} rows")
|
283 |
+
|
284 |
+
# Evaluate model responses over each dataset
|
285 |
+
self.combined_result = {}
|
286 |
+
self.labels = {}
|
287 |
+
self.predictions = {}
|
288 |
+
for dataset in self.datasets:
|
289 |
+
logger.info(f"Evaluating responses for dataset: {dataset}...")
|
290 |
+
if dataset == "snowballing":
|
291 |
+
# Filter responses based on the dataset
|
292 |
+
responses = self.filter_responses(df, dataset)
|
293 |
+
|
294 |
+
# Evaluate the responses
|
295 |
+
result, labels, preds = self.evaluate_snowballing(responses)
|
296 |
+
|
297 |
+
# Store the output and save the results
|
298 |
+
df_out = pd.DataFrame({"gold_labels": labels, "predictions": preds})
|
299 |
+
df_out.to_json(f"{self.output_path}/{self.run_id}/{dataset}_output.jsonl", orient="records", lines=True)
|
300 |
+
self.combined_result[dataset] = result
|
301 |
+
self.labels[dataset] = labels
|
302 |
+
self.predictions[dataset] = preds
|
303 |
+
|
304 |
+
elif dataset == "selfaware":
|
305 |
+
# Filter responses based on the dataset
|
306 |
+
responses = self.filter_responses(df, dataset)
|
307 |
+
|
308 |
+
# Evaluate the responses
|
309 |
+
result, labels, preds = self.evaluate_selfaware(responses[:30])
|
310 |
+
|
311 |
+
# Store the output and save the results
|
312 |
+
df_out = pd.DataFrame({"gold_labels": labels, "predictions": preds})
|
313 |
+
df_out.to_json(f"{self.output_path}/{self.run_id}/{dataset}_output.jsonl", orient="records", lines=True)
|
314 |
+
self.combined_result[dataset] = result
|
315 |
+
self.labels[dataset] = labels
|
316 |
+
self.predictions[dataset] = preds
|
317 |
+
|
318 |
+
elif dataset == "freshqa":
|
319 |
+
# Filter responses based on the dataset
|
320 |
+
responses = self.filter_responses(df, dataset)
|
321 |
+
|
322 |
+
# Evaluate the responses
|
323 |
+
result, raw_evals, preds = self.evaluate_freshqa(responses[:30])
|
324 |
+
|
325 |
+
# Store the output and save the results
|
326 |
+
df_out = pd.DataFrame({"raw_evaluations": raw_evals, "predictions": preds})
|
327 |
+
df_out.to_json(f"{self.output_path}/{self.run_id}/{dataset}_output.jsonl", orient="records", lines=True)
|
328 |
+
self.combined_result[dataset] = result
|
329 |
+
|
330 |
+
elif dataset in ["factoolqa", "felm-wk", "factcheck-bench", "factscore-bio"]:
|
331 |
+
# Check if the freetext key exists
|
332 |
+
if self.combined_result.get("freetext") is None:
|
333 |
+
self.combined_result["freetext"] = {}
|
334 |
+
|
335 |
+
# Filter responses based on the dataset
|
336 |
+
responses = self.filter_responses(df, dataset)
|
337 |
+
|
338 |
+
# Evaluate the responses
|
339 |
+
results, evaluations = self.evaluate_freetext(responses[:30], self.model_name, self.run_id)
|
340 |
+
|
341 |
+
# Store the output and save the results
|
342 |
+
df_out = pd.DataFrame(evaluations)
|
343 |
+
df_out.to_json(f"{self.output_path}/{self.run_id}/{dataset}_output.jsonl", orient="records", lines=True)
|
344 |
+
self.combined_result["freetext"][dataset] = results
|
345 |
+
|
346 |
+
else:
|
347 |
+
logger.error(f"Dataset {dataset} is not supported.")
|
348 |
+
raise ValueError(f"Dataset {dataset} is not supported.")
|
349 |
+
|
350 |
+
logger.info(f"Finished evaluating responses for dataset: {dataset}")
|
351 |
+
|
352 |
+
# save all evaluation results
|
353 |
+
with open(f"{self.output_path}/{self.run_id}/result.json", "w") as json_file:
|
354 |
+
json.dump(self.combined_result, json_file, indent=4)
|
355 |
+
|
356 |
+
# Analyze the results
|
357 |
+
if analyze:
|
358 |
+
self.logger.info("Analyzing the results...")
|
359 |
+
self.generate_plots(save_plots=True, fig_path=f"{self.output_path}/{self.run_id}")
|
360 |
+
|
361 |
+
# Create a report
|
362 |
+
if save_report:
|
363 |
+
self.logger.info("Creating the report...")
|
364 |
+
self.generate_report(report_path=f"{self.output_path}/{self.run_id}")
|
365 |
+
|
366 |
+
return self.combined_result
|
367 |
+
|
368 |
+
|
369 |
+
|
370 |
+
|
371 |
+
|
372 |
+
|
src/openfactcheck/evaluator/llm/evaluate_freetext.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import pandas as pd
|
6 |
+
import seaborn as sns
|
7 |
+
from hashlib import md5
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from openfactcheck.lib.logger import logger
|
11 |
+
from openfactcheck.core.base import OpenFactCheck
|
12 |
+
|
13 |
+
class FreeTextEvaluator():
|
14 |
+
def __init__(self, ofc: OpenFactCheck):
|
15 |
+
"""
|
16 |
+
Initialize the FreeTextEvaluator object.
|
17 |
+
"""
|
18 |
+
|
19 |
+
self.logger = logger
|
20 |
+
|
21 |
+
# Set the OpenFactCheck object
|
22 |
+
self.ofc = ofc
|
23 |
+
|
24 |
+
def calculate_price(self, num_claims, cost_openai=0.015, cost_serper=0.001):
|
25 |
+
"""
|
26 |
+
Calculate the cost (in USD) of the API calls for the free-text experiment.
|
27 |
+
2x API calls per claim
|
28 |
+
|
29 |
+
Parameters
|
30 |
+
----------
|
31 |
+
numClaims : int
|
32 |
+
The number of claims in the free-text experiment.
|
33 |
+
costOpenAI : float
|
34 |
+
The cost of the OpenAI API call.
|
35 |
+
costSerper : float
|
36 |
+
The cost of the Serper API call.
|
37 |
+
"""
|
38 |
+
return num_claims * 2 * (cost_openai + cost_serper)
|
39 |
+
|
40 |
+
def sum_all_elements(self, obj: dict):
|
41 |
+
"""
|
42 |
+
Sum all elements of an object.
|
43 |
+
"""
|
44 |
+
ret = 0
|
45 |
+
for k, v in obj.items():
|
46 |
+
ret += v
|
47 |
+
return ret
|
48 |
+
|
49 |
+
def assess_freetext(self, output_path: str):
|
50 |
+
"""
|
51 |
+
Assess the free-text experiment, i.e., the number and type of claims, this is, Exact Matching (EM).
|
52 |
+
"""
|
53 |
+
|
54 |
+
# Initialize the return object
|
55 |
+
claims = {
|
56 |
+
"num_false_claims": 0,
|
57 |
+
"num_mixed_claims": 0,
|
58 |
+
"num_true_claims": 0,
|
59 |
+
"num_undefined_claims": 0
|
60 |
+
}
|
61 |
+
path = output_path + '/evidence_stance.json'
|
62 |
+
if not os.path.exists(path):
|
63 |
+
return False
|
64 |
+
df = pd.read_json(path, lines=False)
|
65 |
+
dataobj = json.loads(df.to_json())
|
66 |
+
|
67 |
+
# Assess the claims
|
68 |
+
for k, v in dataobj.items():
|
69 |
+
# If stance contains definitive or mixed, then it is false
|
70 |
+
if "definitive" in v["stances"][0] or "mixed" in v["stances"][0]:
|
71 |
+
claims["num_mixed_claims"] += 1
|
72 |
+
elif "factual" in v["stances"][0] or "confirm" in v["stances"][0]:
|
73 |
+
claims["num_true_claims"] += 1
|
74 |
+
elif "error" in v["stances"][0] or "incorrect" in v["stances"][0] or "false" in v["stances"][0]:
|
75 |
+
claims["num_false_claims"] += 1
|
76 |
+
else:
|
77 |
+
claims["num_undefined_claims"] += 1
|
78 |
+
|
79 |
+
return claims
|
80 |
+
|
81 |
+
def read_evaluations(self):
|
82 |
+
"""
|
83 |
+
Read the evaluations from the output directory.
|
84 |
+
"""
|
85 |
+
data = []
|
86 |
+
for dirname in os.listdir(self.base_output_path):
|
87 |
+
dirpath = os.path.join(self.base_output_path, dirname)
|
88 |
+
if os.path.isdir(dirpath):
|
89 |
+
if os.path.exists(os.path.join(dirpath, 'evaluation.json')):
|
90 |
+
with open(os.path.join(dirpath, 'evaluation.json'), 'r') as f:
|
91 |
+
data.append(json.load(f))
|
92 |
+
return data
|
93 |
+
|
94 |
+
def read_results(self, evaluations):
|
95 |
+
"""
|
96 |
+
Read the results from the evaluations.
|
97 |
+
"""
|
98 |
+
# Calculate the total cost and time
|
99 |
+
(costs, time_costs, true_claims, false_claims, mixed_claims, undefined_claims, total_claims) = (0, 0, 0, 0, 0, 0, 0)
|
100 |
+
for evaluation in evaluations:
|
101 |
+
total_claims += 1
|
102 |
+
|
103 |
+
# Calculate the costs
|
104 |
+
costs += self.calculate_price(self.sum_all_elements(evaluation["claims"]))
|
105 |
+
time_costs += evaluation["end"] - evaluation["start"]
|
106 |
+
|
107 |
+
# Calculate the number of claims
|
108 |
+
false_claims += evaluation["claims"]["num_false_claims"]
|
109 |
+
mixed_claims += evaluation["claims"]["num_mixed_claims"]
|
110 |
+
undefined_claims += evaluation["claims"]["num_undefined_claims"]
|
111 |
+
if (evaluation["claims"]["num_false_claims"] + evaluation["claims"]["num_mixed_claims"]) == 0:
|
112 |
+
true_claims += 1
|
113 |
+
|
114 |
+
return{
|
115 |
+
"Claims": total_claims,
|
116 |
+
"True Claims": true_claims,
|
117 |
+
"False Claims": false_claims,
|
118 |
+
"Mixed Claims": mixed_claims,
|
119 |
+
"Undefined Claims": undefined_claims,
|
120 |
+
"Cost (USD)": costs,
|
121 |
+
"Time (ms)": time_costs,
|
122 |
+
"Percentage of True Responses": round(true_claims / total_claims if total_claims != 0 else 0, 3) * 100,
|
123 |
+
"Percentage of False Responses": round(false_claims / total_claims if total_claims != 0 else 0, 3) * 100
|
124 |
+
}
|
125 |
+
|
126 |
+
def freetext_barplot(self, results, fig_path: str = "", save: bool = False):
|
127 |
+
"""
|
128 |
+
Create a barplot for the free-text evaluation results, ensuring full row utilization.
|
129 |
+
|
130 |
+
Parameters
|
131 |
+
----------
|
132 |
+
results : dict
|
133 |
+
The dictionary of results from the free-text evaluation.
|
134 |
+
fig_path : str
|
135 |
+
The path to save the figure.
|
136 |
+
save : bool
|
137 |
+
Whether to save the figure or not.
|
138 |
+
"""
|
139 |
+
|
140 |
+
# Exclude "Claims" and prepare data
|
141 |
+
metrics = list(next(iter(results.values())).keys())
|
142 |
+
datasets = list(results.keys())
|
143 |
+
|
144 |
+
# Prepare plot data and handle specific conversions
|
145 |
+
plot_data = {}
|
146 |
+
for metric in metrics:
|
147 |
+
if metric == "Claims":
|
148 |
+
continue
|
149 |
+
if metric == "Time (s)":
|
150 |
+
plot_data["Time (min)"] = [results[dataset][metric] / (1000 * 60) for dataset in datasets]
|
151 |
+
elif metric == "Percentage of True Responses":
|
152 |
+
plot_data[metric] = [results[dataset][metric] for dataset in datasets]
|
153 |
+
else:
|
154 |
+
plot_data[metric] = [results[dataset][metric] for dataset in datasets]
|
155 |
+
|
156 |
+
# Define the layout
|
157 |
+
total_metrics = len(plot_data)
|
158 |
+
ncols = 4 # Maximum number of columns per row
|
159 |
+
nrows = (total_metrics + ncols - 1) // ncols # Calculate the required number of rows
|
160 |
+
|
161 |
+
# Creating subplots
|
162 |
+
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 5 * nrows))
|
163 |
+
fig.suptitle('Performance on Free-Text Dataset')
|
164 |
+
|
165 |
+
# Flatten axes array if more than one row
|
166 |
+
axes = axes.flatten() if nrows > 1 else [axes]
|
167 |
+
|
168 |
+
# Generate each bar plot and deactivate unused axes
|
169 |
+
for ax, (metric, values) in zip(axes[:total_metrics], plot_data.items()):
|
170 |
+
bars = ax.bar(datasets, values, color=sns.color_palette("rocket", n_colors=len(datasets)))
|
171 |
+
ax.set_title(metric)
|
172 |
+
ax.set_xticks(range(len(datasets)))
|
173 |
+
ax.set_xticklabels(datasets, rotation=45, ha="right")
|
174 |
+
ax.set_ylabel(metric)
|
175 |
+
|
176 |
+
# Annotate each bar with its value
|
177 |
+
for bar in bars:
|
178 |
+
yval = bar.get_height()
|
179 |
+
ax.text(bar.get_x() + bar.get_width()/2, yval, round(yval, 2),
|
180 |
+
ha='center', va='bottom')
|
181 |
+
|
182 |
+
# Set y-axis limits to accommodate annotations
|
183 |
+
ax.set_ylim(0, max(values) * 1.1)
|
184 |
+
|
185 |
+
# Hide unused axes
|
186 |
+
for ax in axes[total_metrics:]:
|
187 |
+
ax.axis('off')
|
188 |
+
|
189 |
+
# Adjust layout to prevent overlap
|
190 |
+
plt.tight_layout()
|
191 |
+
|
192 |
+
if save:
|
193 |
+
plt.savefig(os.path.join(fig_path, "freetext_barplot.pdf"), format="pdf")
|
194 |
+
plt.savefig(os.path.join(fig_path, "freetext_barplot.png"), format="png")
|
195 |
+
|
196 |
+
# Return the figure
|
197 |
+
return fig
|
198 |
+
|
199 |
+
|
200 |
+
def evaluate_freetext(self, llm_responses: list, model_name: str, run_id: str):
|
201 |
+
"""
|
202 |
+
Evaluate the LLM responses on free-text datasets.
|
203 |
+
Currently, FactoolQA, FELM-WK, FactCheck-Bench and FactScore-Bio datasets are included by default.
|
204 |
+
|
205 |
+
Parameters
|
206 |
+
----------
|
207 |
+
llm_responses : list
|
208 |
+
The LLM responses on the free-text datasets.
|
209 |
+
"""
|
210 |
+
|
211 |
+
# Set the pipeline for the FreeTextEvaluator
|
212 |
+
pipeline = [
|
213 |
+
"all_pass_abstain_detector",
|
214 |
+
"factool_decontextualizer",
|
215 |
+
"factool_evidence_retriever",
|
216 |
+
"factool_claim_examiner",
|
217 |
+
"factool_post_editor",
|
218 |
+
"concat_response_generator"
|
219 |
+
]
|
220 |
+
|
221 |
+
# Initialize the pipeline manually
|
222 |
+
self.ofc.init_pipeline_manually(pipeline=pipeline)
|
223 |
+
|
224 |
+
# Get the dataset name and create DataFrame
|
225 |
+
dataset = llm_responses[0]['source']
|
226 |
+
llm_responses = pd.DataFrame(llm_responses)
|
227 |
+
|
228 |
+
# Save the base_output_path
|
229 |
+
self.base_output_path = f"{self.ofc.output_path}/llm_evaluator/{run_id}/{dataset}"
|
230 |
+
|
231 |
+
# Evaluate the LLM responses
|
232 |
+
for idx, response in llm_responses.iterrows():
|
233 |
+
|
234 |
+
prompt = response['prompt']
|
235 |
+
response = response['response']
|
236 |
+
|
237 |
+
# Set the output path
|
238 |
+
output_path = f"{self.base_output_path}/{idx}_{md5(prompt.encode()).hexdigest()}"
|
239 |
+
if not os.path.exists(output_path):
|
240 |
+
os.makedirs(output_path)
|
241 |
+
|
242 |
+
# If the file was already evaluated, skip it
|
243 |
+
if (os.path.exists(f"{self.base_output_path}/{idx}_{md5(prompt.encode()).hexdigest()}/evaluation.json")):
|
244 |
+
logger.info(f"Skipping the evaluation for prompt {idx} as it was already evaluated.")
|
245 |
+
continue
|
246 |
+
|
247 |
+
# TODO: This should work (but it doesn't)
|
248 |
+
# self.ofc.init_solver("factool_evidence_retriever", {"path_save_evidence": f"{output_path}/evidence.json"})
|
249 |
+
|
250 |
+
# Evaluate the response
|
251 |
+
start = time.time() * 1000
|
252 |
+
_result = self.ofc.ResponseEvaluator.evaluate(
|
253 |
+
response=response,
|
254 |
+
prompt=prompt,
|
255 |
+
sample_name=f"llm_evaluator/{run_id}/truth/{dataset}/{idx}"
|
256 |
+
)
|
257 |
+
end = time.time() * 1000
|
258 |
+
|
259 |
+
# TODO: This is a workaround for the TODO above (move the evidence.json file)
|
260 |
+
if os.path.exists("evidence.json"):
|
261 |
+
os.rename("evidence.json", f"{output_path}/evidence.json")
|
262 |
+
if os.path.exists("evidence_stance.json"):
|
263 |
+
os.rename("evidence_stance.json", f"{output_path}/evidence_stance.json")
|
264 |
+
|
265 |
+
# Assess the free-text experiment
|
266 |
+
claims = self.assess_freetext(output_path)
|
267 |
+
if not claims:
|
268 |
+
self.logger.warning(f'Error in assessing experiment for prompt {idx}')
|
269 |
+
continue
|
270 |
+
|
271 |
+
# Persist the output
|
272 |
+
result = {}
|
273 |
+
result["start"] = math.floor(start)
|
274 |
+
result["end"] = math.floor(end)
|
275 |
+
result["llm"] = model_name
|
276 |
+
result["dataset"] = llm_responses["source"][idx]
|
277 |
+
result["prompt"] = prompt
|
278 |
+
result["claims"] = claims
|
279 |
+
result["result"] = _result
|
280 |
+
|
281 |
+
# Save the result
|
282 |
+
logger.debug(f"Saving the result for prompt {idx} in {output_path}/evaluation.json")
|
283 |
+
with open(f"{output_path}/evaluation.json", "w") as f:
|
284 |
+
json.dump(result, f, indent=4)
|
285 |
+
|
286 |
+
logger.info(f"Evaluated the LLM response for prompt {idx} in {end - start} ms.")
|
287 |
+
|
288 |
+
logger.info(f"Finished evaluating the LLM responses for the {dataset} dataset.")
|
289 |
+
|
290 |
+
# Read the outputs
|
291 |
+
evaluations = self.read_evaluations()
|
292 |
+
|
293 |
+
# Read the results
|
294 |
+
results = self.read_results(evaluations)
|
295 |
+
|
296 |
+
return results, evaluations
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
|
303 |
+
|
304 |
+
|
305 |
+
|
src/openfactcheck/evaluator/llm/evaluate_freshqa.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import pytz
|
4 |
+
import datetime
|
5 |
+
import pandas as pd
|
6 |
+
import seaborn as sns
|
7 |
+
from openai import OpenAI
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
class FreshQAEvaluator():
|
11 |
+
def __init__(self):
|
12 |
+
self.openai_client = OpenAI()
|
13 |
+
self.openai_model_name = "gpt-4o-mini"
|
14 |
+
self.current_date = datetime.datetime.now(pytz.timezone("America/Los_Angeles")).strftime("%B %d, %Y")
|
15 |
+
self.prefix = (
|
16 |
+
"Please evaluate the response to a question under strict evaluation, where"
|
17 |
+
" no hallucinations, outdated information, or ill-formed answers are"
|
18 |
+
" allowed. Please credit the response only if it provides a confident and"
|
19 |
+
" definitive answer, or the correct answer can be obviously inferred from"
|
20 |
+
" the response. The primary or final answer when standing alone must be"
|
21 |
+
" accurate. Any additional information that is provided must not contradict"
|
22 |
+
" the primary answer or reshape one's perception of it. For false-premise"
|
23 |
+
" questions, the response must point out the presence of a false premise to"
|
24 |
+
" receive credit. For answers that involve names of entities (e.g.,"
|
25 |
+
" people), complete names or commonly recognized names are expected."
|
26 |
+
" Regarding numerical answers, approximate numbers are generally not"
|
27 |
+
" accepted unless explicitly included in the ground-truth answers. A"
|
28 |
+
" response that contains any hallucination, no matter how minor, will not"
|
29 |
+
" receive credit. Furthermore, when the response indicates that the"
|
30 |
+
" information might be outdated, we accept it only if it is evident that"
|
31 |
+
" the knowledge has not changed (e.g., through common sense or well-known"
|
32 |
+
" facts)."
|
33 |
+
)
|
34 |
+
|
35 |
+
self.demo_examples = [
|
36 |
+
{
|
37 |
+
"question": "How old is the world's oldest verified living person?",
|
38 |
+
"correct_answers": [
|
39 |
+
"116 years old",
|
40 |
+
"116",
|
41 |
+
],
|
42 |
+
"response": (
|
43 |
+
f"As of today {self.current_date}, the most up-to-date and relevant"
|
44 |
+
" information regarding this query is as follows. The oldest"
|
45 |
+
" verified living person is Maria Branyas Morera, who was born"
|
46 |
+
" on March 4, 1907, making her 116 years old."
|
47 |
+
),
|
48 |
+
"comment": (
|
49 |
+
"This is a valid question. The primary answer in the response"
|
50 |
+
" (116 years old) is accurate, and all the provided information"
|
51 |
+
" is accurate and up-to-date. Thus, the response is credited."
|
52 |
+
),
|
53 |
+
"evaluation": "correct",
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"question": "When did the UK adopt the Euro?",
|
57 |
+
"correct_answers": ["The United Kingdom has never adopted the Euro."],
|
58 |
+
"response": (
|
59 |
+
"The UK has never adopted the Euro as its official currency. The"
|
60 |
+
" country has retained the British pound sterling (GBP) as its"
|
61 |
+
" currency."
|
62 |
+
),
|
63 |
+
"comment": (
|
64 |
+
"The question contains a false premise. The response debunks the"
|
65 |
+
" question's false premise (i.e., it points out that the UK has"
|
66 |
+
" never adopted the Euro), and all the provided information is"
|
67 |
+
" accurate and up-to-date. Thus, the response is credited."
|
68 |
+
),
|
69 |
+
"evaluation": "correct",
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"question": (
|
73 |
+
"How much longer will Brittney Griner spend in the Russian prison?"
|
74 |
+
),
|
75 |
+
"correct_answers": [
|
76 |
+
"She was released in December 2022 as part of a prisoner swap."
|
77 |
+
],
|
78 |
+
"response": (
|
79 |
+
"I'm sorry, but I have no information to suggest that Brittney"
|
80 |
+
" Griner is currently in a Russian prison. As far as I know, she"
|
81 |
+
" is a professional basketball player who currently plays for the"
|
82 |
+
" Phoenix Mercury in the Women's National Basketball Association"
|
83 |
+
" (WNBA). If you have any more information or context, please let"
|
84 |
+
" me know and I'll do my best to help."
|
85 |
+
),
|
86 |
+
"comment": (
|
87 |
+
"The question contains a false premise. The response fails to"
|
88 |
+
" debunk the question's false premise by not explicitly stating"
|
89 |
+
" that Brittney Griner was already released. Additionally, the"
|
90 |
+
" response lacks a confident and definitive answer. Thus, the"
|
91 |
+
" response is not credited."
|
92 |
+
),
|
93 |
+
"evaluation": "incorrect",
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"question": (
|
97 |
+
"Which language has the third largest population by number of"
|
98 |
+
" native speakers?"
|
99 |
+
),
|
100 |
+
"correct_answers": ["English"],
|
101 |
+
"response": "1. Mandarin 2. Spanish 3. English",
|
102 |
+
"comment": (
|
103 |
+
"This is a valid question. While the correct answer (English) can"
|
104 |
+
" be inferred from the response, the response lacks an explicit"
|
105 |
+
" statement, which is not desirable."
|
106 |
+
" Thus, the response is not credited."
|
107 |
+
),
|
108 |
+
"evaluation": "incorrect",
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"question": (
|
112 |
+
"Has Virginia Woolf's novel about the Ramsay family entered the"
|
113 |
+
" public domain in the United States?"
|
114 |
+
),
|
115 |
+
"correct_answers": ["Yes"],
|
116 |
+
"response": (
|
117 |
+
"1927's To the Lighthouse is in the public domain in the United"
|
118 |
+
" States. The final answer is To the Lighthouse."
|
119 |
+
),
|
120 |
+
"comment": (
|
121 |
+
"This is a valid question. While the primary information in the"
|
122 |
+
" response (1927's To the Lighthouse is in the public domain in"
|
123 |
+
" the United States) is accurate, the final answer when standing"
|
124 |
+
' alone is inaccurate (it should be "Yes" instead of "To the'
|
125 |
+
' Lighthouse"). Thus, the response is not credited.'
|
126 |
+
),
|
127 |
+
"evaluation": "incorrect",
|
128 |
+
},
|
129 |
+
]
|
130 |
+
|
131 |
+
self.demo_questions = [ex["question"] for ex in self.demo_examples]
|
132 |
+
self.demo_evaluations = []
|
133 |
+
|
134 |
+
self.demo_evaluation_template = (
|
135 |
+
"\ncorrect answer(s): {correct_answers}"
|
136 |
+
"\nresponse: {response}"
|
137 |
+
"\ncomment: {comment}"
|
138 |
+
"\nevaluation: {evaluation}"
|
139 |
+
)
|
140 |
+
|
141 |
+
self.evaluation_template = (
|
142 |
+
"\ncorrect answer(s): {correct_answers}" "\nresponse: {response}" "\ncomment: "
|
143 |
+
)
|
144 |
+
|
145 |
+
self.df = pd.DataFrame()
|
146 |
+
|
147 |
+
def call_openai_api(self, prompt, temperature, max_tokens, chat_completions=True):
|
148 |
+
"""
|
149 |
+
Call the OpenAI API to generate responses.
|
150 |
+
"""
|
151 |
+
|
152 |
+
# Generate prompts for demo examples
|
153 |
+
for ex in self.demo_examples:
|
154 |
+
demo_evaluation = self.demo_evaluation_template.format(
|
155 |
+
question=ex["question"],
|
156 |
+
correct_answers=" | ".join(ex["correct_answers"]),
|
157 |
+
response=ex["response"],
|
158 |
+
comment=ex["comment"],
|
159 |
+
evaluation=ex["evaluation"],
|
160 |
+
)
|
161 |
+
self.demo_evaluations.append(demo_evaluation)
|
162 |
+
|
163 |
+
# Call the OpenAI API to generate responses
|
164 |
+
# If chat completions are enabled, use the chat completions endpoint
|
165 |
+
if chat_completions:
|
166 |
+
response = self.openai_client.chat.completions.create(
|
167 |
+
model=self.openai_model_name,
|
168 |
+
temperature=temperature,
|
169 |
+
max_tokens=max_tokens,
|
170 |
+
messages=[
|
171 |
+
{
|
172 |
+
"role": "system",
|
173 |
+
"content": (
|
174 |
+
"You are a helpful assistant. Respond as concisely as"
|
175 |
+
f" possible. Knowledge cutoff: {self.current_date}."
|
176 |
+
),
|
177 |
+
},
|
178 |
+
{"role": "user", "content": "What's today's date?"},
|
179 |
+
{
|
180 |
+
"role": "assistant",
|
181 |
+
"content": f"Today is {self.current_date} in Pacific Standard Time.",
|
182 |
+
},
|
183 |
+
{"role": "user", "content": prompt},
|
184 |
+
],
|
185 |
+
)
|
186 |
+
return response.choices[0].message.content
|
187 |
+
|
188 |
+
# If chat completions are disabled, use the completions endpoint
|
189 |
+
else:
|
190 |
+
response = self.openai_client.completions.create(
|
191 |
+
model=self.openai_model_name,
|
192 |
+
temperature=temperature,
|
193 |
+
max_tokens=max_tokens,
|
194 |
+
prompt=prompt,
|
195 |
+
)
|
196 |
+
return response.choices[0].text
|
197 |
+
|
198 |
+
|
199 |
+
def call_fresheval(self, prefix, question, response, correct_answers, evaluation):
|
200 |
+
"""
|
201 |
+
Call the FreshEval API to evaluate responses.
|
202 |
+
"""
|
203 |
+
|
204 |
+
# Set the parameters for the OpenAI API
|
205 |
+
temperature = 0.0
|
206 |
+
max_tokens = 256
|
207 |
+
chat_completions = True
|
208 |
+
|
209 |
+
# Generate prompts for demo examples
|
210 |
+
demo_prompts = []
|
211 |
+
for q, e in zip(self.demo_questions, self.demo_evaluations):
|
212 |
+
demo_prompts.append(f"\n\n\nquestion: {q}{e}")
|
213 |
+
|
214 |
+
# Generate the fresh evaluation prompt
|
215 |
+
fresheval_demo = "".join(demo_prompts).strip()
|
216 |
+
fresheval_question = f"\n\n\nquestion: {question}{evaluation}"
|
217 |
+
|
218 |
+
# Call the OpenAI API to generate responses
|
219 |
+
fresh_eval = prefix + "\n\n\n" + fresheval_demo + fresheval_question
|
220 |
+
answer = self.call_openai_api(fresh_eval, temperature, max_tokens, chat_completions)
|
221 |
+
|
222 |
+
return answer
|
223 |
+
|
224 |
+
def extract_ratings(self, response):
|
225 |
+
"""
|
226 |
+
Extract the rating from the evaluation response.
|
227 |
+
"""
|
228 |
+
|
229 |
+
# If the eval answer contains either of these three words, considered as 0
|
230 |
+
# including incorrect, not correct, not credited
|
231 |
+
pattern = re.compile(
|
232 |
+
r"\b(?:incorrect|not\s+correct|not\s+credited)\b", re.IGNORECASE
|
233 |
+
)
|
234 |
+
if pattern.search(response):
|
235 |
+
return 0
|
236 |
+
else:
|
237 |
+
return 1
|
238 |
+
|
239 |
+
def freshqa_piechart(self, result, fig_path: str = "", save: bool = False):
|
240 |
+
"""
|
241 |
+
Plot a pie chart of the true and false answers on FreshQA.
|
242 |
+
|
243 |
+
Parameters
|
244 |
+
----------
|
245 |
+
result : dict
|
246 |
+
The evaluation result.
|
247 |
+
fig_path : str
|
248 |
+
The path to save the figure.
|
249 |
+
save : bool, optional
|
250 |
+
Whether to save the figure, by default True.
|
251 |
+
"""
|
252 |
+
|
253 |
+
# Given numbers
|
254 |
+
sizes = [result["accuracy"], 1 - result["accuracy"]]
|
255 |
+
labels = ["True Answer", "False Answer"]
|
256 |
+
colors = [(0, 1, 0, 0.5), (1, 0, 0, 0.5)] # Red and green with 50% transparency
|
257 |
+
|
258 |
+
# Create a new figure
|
259 |
+
fig, ax = plt.subplots()
|
260 |
+
|
261 |
+
# Plot pie chart
|
262 |
+
plt.pie(sizes, labels=labels, colors=colors, autopct="%1.1f%%", startangle=140)
|
263 |
+
plt.title("Performance on FreshQA Dataset")
|
264 |
+
|
265 |
+
# Equal aspect ratio ensures that pie is drawn as a circle
|
266 |
+
plt.axis("equal")
|
267 |
+
|
268 |
+
if save:
|
269 |
+
# Save the figure
|
270 |
+
plt.tight_layout()
|
271 |
+
plt.savefig(os.path.join(fig_path, "freshqa_piechart.pdf"), format="pdf")
|
272 |
+
plt.savefig(os.path.join(fig_path, "freshqa_piechart.png"), format="png")
|
273 |
+
|
274 |
+
# Return the figure
|
275 |
+
return fig
|
276 |
+
|
277 |
+
|
278 |
+
def evaluate_freshqa(self, llm_responses):
|
279 |
+
"""
|
280 |
+
Evaluate the responses generated by the LLM on FreshQA questions.
|
281 |
+
"""
|
282 |
+
|
283 |
+
llm_responses = pd.DataFrame(llm_responses)
|
284 |
+
raw_evals = []
|
285 |
+
preds = []
|
286 |
+
for idx, row in llm_responses.iterrows():
|
287 |
+
evaluation = self.evaluation_template.format(
|
288 |
+
correct_answers=row["reference_answer"],
|
289 |
+
response=row["response"],
|
290 |
+
)
|
291 |
+
|
292 |
+
fresheval = self.call_fresheval(
|
293 |
+
self.prefix,
|
294 |
+
row["question"],
|
295 |
+
row["response"],
|
296 |
+
row["reference_answer"],
|
297 |
+
evaluation,
|
298 |
+
)
|
299 |
+
|
300 |
+
evaluation_rating = self.extract_ratings(fresheval)
|
301 |
+
raw_evals.append(evaluation)
|
302 |
+
preds.append(evaluation_rating)
|
303 |
+
|
304 |
+
# Compute the accuracy (percentage of correct evaluations)
|
305 |
+
accuracy = sum(preds) / len(preds)
|
306 |
+
result = { "accuracy": accuracy }
|
307 |
+
return result, raw_evals, preds
|
308 |
+
|
src/openfactcheck/evaluator/llm/evaluate_selfaware.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import torch
|
4 |
+
import string
|
5 |
+
import numpy as np
|
6 |
+
import seaborn as sns
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from transformers import AutoTokenizer, AutoModel
|
9 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
10 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
11 |
+
|
12 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
13 |
+
|
14 |
+
class SelfAwareEvaluator():
|
15 |
+
def __init__(self):
|
16 |
+
pass
|
17 |
+
|
18 |
+
def remove_punctuation(self, input_string):
|
19 |
+
"""
|
20 |
+
Remove the punctuation from the input string.
|
21 |
+
"""
|
22 |
+
input_string = input_string.strip().lower()
|
23 |
+
if input_string and input_string[-1] in string.punctuation:
|
24 |
+
return input_string[:-1]
|
25 |
+
return input_string
|
26 |
+
|
27 |
+
|
28 |
+
def cut_sentences(self, content):
|
29 |
+
"""
|
30 |
+
Cut the content into sentences.
|
31 |
+
"""
|
32 |
+
sentences = re.split(r"(\.|\!|\?|。|!|?|\.{6})", content)
|
33 |
+
return sentences
|
34 |
+
|
35 |
+
|
36 |
+
def cut_sub_string(self, input_string, window_size=5, punctuation=".,?!"):
|
37 |
+
"""
|
38 |
+
Cut the input string into sub-strings of a fixed window size.
|
39 |
+
"""
|
40 |
+
input_string = input_string.strip().lower()
|
41 |
+
|
42 |
+
# If the input string is empty, return an empty list
|
43 |
+
if len(input_string) < 2:
|
44 |
+
return [""]
|
45 |
+
|
46 |
+
# Remove the punctuation
|
47 |
+
if input_string[-1] in punctuation:
|
48 |
+
input_string = input_string[:-1]
|
49 |
+
|
50 |
+
# Split the string into words
|
51 |
+
string_list = input_string.split()
|
52 |
+
|
53 |
+
# If the length of the string is less than the window size, return the string
|
54 |
+
length = len(string_list)
|
55 |
+
if length <= window_size:
|
56 |
+
return [input_string]
|
57 |
+
else:
|
58 |
+
res = []
|
59 |
+
for i in range(length - window_size + 1):
|
60 |
+
sub_string = " ".join(string_list[i: i + window_size])
|
61 |
+
if sub_string != "" or sub_string != " ":
|
62 |
+
res.append(sub_string)
|
63 |
+
return res
|
64 |
+
|
65 |
+
|
66 |
+
def group_cosine_similarity(self, model, tokenizer, sentences1, sentences2):
|
67 |
+
"""
|
68 |
+
Calculate the cosine similarity between two groups of sentences.
|
69 |
+
"""
|
70 |
+
# Move the model to the device
|
71 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
72 |
+
model.to(device)
|
73 |
+
|
74 |
+
# Encode the sentences
|
75 |
+
encoded1 = tokenizer(sentences1, padding=True, truncation=True, return_tensors="pt").to(device)
|
76 |
+
encoded2 = tokenizer(sentences2, padding=True, truncation=True, return_tensors="pt").to(device)
|
77 |
+
|
78 |
+
# Get the embeddings
|
79 |
+
embeddings1 = model(**encoded1).pooler_output
|
80 |
+
embeddings2 = model(**encoded2).pooler_output
|
81 |
+
|
82 |
+
# Calculate the cosine similarity
|
83 |
+
similarities = cosine_similarity(embeddings1.detach().cpu().numpy(), embeddings2.detach().cpu().numpy())
|
84 |
+
|
85 |
+
return similarities
|
86 |
+
|
87 |
+
|
88 |
+
def get_unanswerable(self, response, model, tokenizer, threshold=0.75):
|
89 |
+
"""
|
90 |
+
Predict whether the response is unanswerable or not.
|
91 |
+
"""
|
92 |
+
# List of uncertain words
|
93 |
+
uncertain_list = [
|
94 |
+
"The answer is unknown.",
|
95 |
+
"The answer is uncertain.",
|
96 |
+
"The answer is unclear.",
|
97 |
+
"There is no scientific evidence.",
|
98 |
+
"There is no definitive answer.",
|
99 |
+
"There is no right answer.",
|
100 |
+
"There is much debate.",
|
101 |
+
"There is no known case.",
|
102 |
+
"There is no concrete answer to this question.",
|
103 |
+
"There is no public information available.",
|
104 |
+
"It is impossible to know.",
|
105 |
+
"It is impossible to answer.",
|
106 |
+
"It is difficult to predict.",
|
107 |
+
"It is not known.",
|
108 |
+
"We do not know.",
|
109 |
+
"I'm not sure.",
|
110 |
+
]
|
111 |
+
uncertain_list = [self.remove_punctuation(_) for _ in uncertain_list]
|
112 |
+
|
113 |
+
# Check if the response contains uncertain words
|
114 |
+
pred_unanswerable = False
|
115 |
+
response = response.lower()
|
116 |
+
for uncertain in uncertain_list:
|
117 |
+
if uncertain in response:
|
118 |
+
pred_unanswerable = True
|
119 |
+
|
120 |
+
# If the response does not contain uncertain words, check the cosine similarity
|
121 |
+
if pred_unanswerable == False:
|
122 |
+
# Cut the response into sentences
|
123 |
+
sub_sen_list = self.cut_sentences(response)
|
124 |
+
|
125 |
+
# Cut the sentences into sub-strings
|
126 |
+
sub_str_list = []
|
127 |
+
for sub_sen in sub_sen_list:
|
128 |
+
if len(sub_sen) >= 2:
|
129 |
+
sub_str_list.extend(self.cut_sub_string(sub_sen))
|
130 |
+
|
131 |
+
# Calculate the cosine similarity
|
132 |
+
if len(sub_str_list) != 0:
|
133 |
+
similarities = self.group_cosine_similarity(model, tokenizer, sub_str_list, uncertain_list)
|
134 |
+
else:
|
135 |
+
similarities = [0]
|
136 |
+
|
137 |
+
# Check if the maximum similarity is greater than the threshold
|
138 |
+
max_uncertainty = np.max(similarities)
|
139 |
+
|
140 |
+
# If the maximum similarity is greater than the threshold, predict unanswerable
|
141 |
+
if max_uncertainty > threshold:
|
142 |
+
pred_unanswerable = True
|
143 |
+
|
144 |
+
return pred_unanswerable
|
145 |
+
|
146 |
+
def selfaware_barplot(self, result: dict, fig_path: str = "", save: bool = False):
|
147 |
+
"""
|
148 |
+
Create a bar plot of the performance on the SelfAware dataset.
|
149 |
+
|
150 |
+
Parameters
|
151 |
+
----------
|
152 |
+
result : dict
|
153 |
+
The evaluation results for the LLM responses on the SelfAware dataset.
|
154 |
+
fig_path : str
|
155 |
+
The path to save the figure.
|
156 |
+
save : bool, optional
|
157 |
+
Whether to save the figure, by default True.
|
158 |
+
"""
|
159 |
+
|
160 |
+
# Data
|
161 |
+
unanswerable_as_pos = result["unanswerable_as_pos"]
|
162 |
+
answerable_as_pos = result["answerable_as_pos"]
|
163 |
+
|
164 |
+
# Remove support
|
165 |
+
unanswerable_as_pos.pop("support", None)
|
166 |
+
answerable_as_pos.pop("support", None)
|
167 |
+
|
168 |
+
# Extract the accuracy values for each topic
|
169 |
+
metrics = list(unanswerable_as_pos.keys())
|
170 |
+
unanswerable_values = [round(v, 2) for k, v in unanswerable_as_pos.items()]
|
171 |
+
answerable_values = [round(v, 2) for k, v in answerable_as_pos.items()]
|
172 |
+
|
173 |
+
# Create a new figure
|
174 |
+
fig, ax = plt.subplots()
|
175 |
+
|
176 |
+
# Number of groups
|
177 |
+
n_groups = len(metrics)
|
178 |
+
index = np.arange(n_groups)
|
179 |
+
bar_width = 0.35
|
180 |
+
|
181 |
+
# Select two colors from the "rocket" palette
|
182 |
+
colors = sns.color_palette("rocket", n_colors=10)
|
183 |
+
color_unanswerable = colors[1]
|
184 |
+
color_answerable = colors[7]
|
185 |
+
|
186 |
+
# Plotting both sets of data
|
187 |
+
bars1 = ax.bar(index, unanswerable_values, bar_width, label='Unanswerable as Positive', color=color_unanswerable)
|
188 |
+
bars2 = ax.bar(index + bar_width, answerable_values, bar_width, label='Answerable as Positive', color=color_answerable)
|
189 |
+
|
190 |
+
# Adding values on top of each bar
|
191 |
+
for bar in bars1:
|
192 |
+
ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f'{bar.get_height():.2f}',
|
193 |
+
ha='center', va='bottom', color='black', rotation='horizontal')
|
194 |
+
for bar in bars2:
|
195 |
+
ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f'{bar.get_height():.2f}',
|
196 |
+
ha='center', va='bottom', color='black', rotation='horizontal')
|
197 |
+
|
198 |
+
# Set x-axis labels
|
199 |
+
ax.set_xticks(index + bar_width / 2)
|
200 |
+
ax.set_xticklabels(metrics)
|
201 |
+
|
202 |
+
# Set y-axis limits to accommodate annotations
|
203 |
+
ax.set_ylim((0, max(unanswerable_values + answerable_values) + 0.1))
|
204 |
+
|
205 |
+
# Adding labels and title
|
206 |
+
ax.set_xlabel("Metrics")
|
207 |
+
ax.set_ylabel("Performance")
|
208 |
+
ax.set_title("Performance on SelfAware Dataset")
|
209 |
+
ax.legend()
|
210 |
+
|
211 |
+
if save:
|
212 |
+
# Save the figure
|
213 |
+
plt.tight_layout()
|
214 |
+
plt.savefig(os.path.join(fig_path, "selfaware_barplot.pdf"), format="pdf")
|
215 |
+
plt.savefig(os.path.join(fig_path, "selfaware_barplot.png"), format="png")
|
216 |
+
|
217 |
+
# Return the figure
|
218 |
+
return fig
|
219 |
+
|
220 |
+
def selfaware_cm(self, labels: list, preds: list, fig_path: str = "", save: bool = False):
|
221 |
+
"""
|
222 |
+
Create a confusion matrix for the SelfAware dataset.
|
223 |
+
|
224 |
+
Parameters
|
225 |
+
----------
|
226 |
+
labels : list
|
227 |
+
The true labels.
|
228 |
+
preds : list
|
229 |
+
The predicted labels.
|
230 |
+
fig_path : str
|
231 |
+
The path to save the figure.
|
232 |
+
save : bool, optional
|
233 |
+
Whether to save the figure, by default True.
|
234 |
+
"""
|
235 |
+
|
236 |
+
# Create a new figure
|
237 |
+
fig, ax = plt.subplots()
|
238 |
+
|
239 |
+
# Compute confusion matrix
|
240 |
+
cm = sns.heatmap(confusion_matrix(labels, preds), annot=True, fmt="d", cmap="Blues", ax=ax)
|
241 |
+
|
242 |
+
# Adding labels and title
|
243 |
+
plt.xticks(ticks=[0.5, 1.5], labels=["Answerable", "Unanswerable"])
|
244 |
+
plt.yticks(ticks=[0.5, 1.5], labels=["Answerable", "Unanswerable"])
|
245 |
+
plt.ylabel("True label")
|
246 |
+
plt.xlabel("Predicted label")
|
247 |
+
plt.title("Confusion Matrix on SelfAware dataset.")
|
248 |
+
|
249 |
+
if save:
|
250 |
+
# Save the figure
|
251 |
+
plt.tight_layout()
|
252 |
+
plt.savefig(os.path.join(fig_path, "selfaware_cm.pdf"), format="pdf")
|
253 |
+
plt.savefig(os.path.join(fig_path, "selfaware_cm.png"), format="png")
|
254 |
+
|
255 |
+
# Return the figure
|
256 |
+
return fig
|
257 |
+
|
258 |
+
def evaluate_selfaware(self, llm_responses):
|
259 |
+
# Load the model
|
260 |
+
model_name = "princeton-nlp/sup-simcse-roberta-large"
|
261 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
262 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
263 |
+
model = AutoModel.from_pretrained(model_name).to(device)
|
264 |
+
|
265 |
+
# Evaluate the LLM responses
|
266 |
+
labels, preds = [], []
|
267 |
+
for item in llm_responses:
|
268 |
+
# gold label: whether the question is answerable or not.
|
269 |
+
labels.append(item["label_unanswerable"])
|
270 |
+
# identify whether the model response answered the question or not.
|
271 |
+
preds.append(self.get_unanswerable(item["response"], model, tokenizer))
|
272 |
+
|
273 |
+
# Change the key names
|
274 |
+
result = classification_report(labels, preds, output_dict=True, zero_division=0)
|
275 |
+
|
276 |
+
# Renaming keys based on the expected output dictionary structure
|
277 |
+
# Unanswerable as positive class and answerable as negative class
|
278 |
+
if "True" in result:
|
279 |
+
result['unanswerable_as_pos'] = result.pop("True")
|
280 |
+
if "False" in result:
|
281 |
+
result['answerable_as_pos'] = result.pop('False')
|
282 |
+
|
283 |
+
return result, labels, preds
|
284 |
+
|
src/openfactcheck/evaluator/llm/evaluate_snowballing.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import numpy as np
|
4 |
+
import seaborn as sns
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
7 |
+
|
8 |
+
class SnowballingEvaluator():
|
9 |
+
"""
|
10 |
+
Evaluate the LLM responses on the Snowballing dataset.
|
11 |
+
|
12 |
+
Parameters
|
13 |
+
----------
|
14 |
+
LLMEvaluator : class
|
15 |
+
The LLMEvaluator class.
|
16 |
+
|
17 |
+
Methods
|
18 |
+
-------
|
19 |
+
evaluate_snowballing(llm_responses: list):
|
20 |
+
Evaluate the LLM responses on the Snowballing dataset
|
21 |
+
snowballing_barplot(result: dict, fig_path: str, save: bool = False):
|
22 |
+
Create a bar plot of the accuracy of the LLM responses on the Snowballing dataset
|
23 |
+
for each topic and the overall accuracy.
|
24 |
+
get_boolean(response: str, strict=False):
|
25 |
+
Get a boolean value from the response.
|
26 |
+
"""
|
27 |
+
def __init__(self):
|
28 |
+
pass
|
29 |
+
|
30 |
+
def get_boolean(self, response: str, strict=False):
|
31 |
+
"""
|
32 |
+
Get a boolean value from the response.
|
33 |
+
|
34 |
+
"""
|
35 |
+
low_response = response.lower()
|
36 |
+
if strict:
|
37 |
+
if low_response.startswith("yes"):
|
38 |
+
return True
|
39 |
+
elif low_response.startswith("no"):
|
40 |
+
return False
|
41 |
+
return None
|
42 |
+
else:
|
43 |
+
# Check if the response contains any of the specified words
|
44 |
+
pattern = r"{}".format("|".join(["n't", "no"]))
|
45 |
+
if bool(re.search(pattern, response, re.IGNORECASE)):
|
46 |
+
return False
|
47 |
+
else:
|
48 |
+
return True
|
49 |
+
|
50 |
+
def snowballing_barplot(self, result: dict, fig_path: str = "", save: bool = False):
|
51 |
+
"""
|
52 |
+
Create a bar plot of the accuracy of the LLM responses on the Snowballing dataset
|
53 |
+
for each topic and the overall accuracy.
|
54 |
+
|
55 |
+
Parameters
|
56 |
+
----------
|
57 |
+
cresult : dict
|
58 |
+
The evaluation results for the LLM responses on the Snowballing dataset
|
59 |
+
fig_path : str
|
60 |
+
The path to save the figure.
|
61 |
+
save : bool, optional
|
62 |
+
Whether to save the figure, by default True.
|
63 |
+
"""
|
64 |
+
|
65 |
+
# Data
|
66 |
+
items = result.keys()
|
67 |
+
|
68 |
+
# Extract the accuracy values for each topic
|
69 |
+
values = [round(v["accuracy"], 2) for k, v in result.items()]
|
70 |
+
|
71 |
+
# Create a new figure
|
72 |
+
fig, ax = plt.subplots()
|
73 |
+
|
74 |
+
# Plotting
|
75 |
+
bars = sns.barplot(x=items, y=values, palette="rocket", hue=items, ax=ax)
|
76 |
+
|
77 |
+
# Adding values on top of each bar
|
78 |
+
for bar in bars.patches:
|
79 |
+
ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
|
80 |
+
f'{bar.get_height():.2f}',
|
81 |
+
ha='center',
|
82 |
+
va='bottom',
|
83 |
+
color='black',
|
84 |
+
rotation='horizontal')
|
85 |
+
|
86 |
+
# Rotating x-axis tick labels
|
87 |
+
plt.xticks(rotation=20)
|
88 |
+
|
89 |
+
# Set y-axis limits to accommodate annotations
|
90 |
+
plt.ylim((0, max(values) + 0.1))
|
91 |
+
|
92 |
+
# Adding labels and title
|
93 |
+
plt.xlabel("Topics")
|
94 |
+
plt.ylabel("Accuracy")
|
95 |
+
plt.title("Performance on Snowballing Dataset.")
|
96 |
+
|
97 |
+
if save:
|
98 |
+
# Save the figure
|
99 |
+
plt.tight_layout()
|
100 |
+
plt.savefig(os.path.join(fig_path, "snowballing_barplot.pdf"), format="pdf")
|
101 |
+
plt.savefig(os.path.join(fig_path, "snowballing_barplot.png"), format="png")
|
102 |
+
|
103 |
+
# Return the figure
|
104 |
+
return fig
|
105 |
+
|
106 |
+
def snowballing_cm(self, labels: list, preds: list, fig_path: str = "", save: bool = False):
|
107 |
+
"""
|
108 |
+
Create a confusion matrix for the Snowballing dataset.
|
109 |
+
|
110 |
+
Parameters
|
111 |
+
----------
|
112 |
+
labels : list
|
113 |
+
The true labels.
|
114 |
+
preds : list
|
115 |
+
The predicted labels.
|
116 |
+
fig_path : str
|
117 |
+
The path to save the figure.
|
118 |
+
save : bool, optional
|
119 |
+
Whether to save the figure, by default True.
|
120 |
+
"""
|
121 |
+
|
122 |
+
# Create a new figure
|
123 |
+
fig, ax = plt.subplots()
|
124 |
+
|
125 |
+
# Plotting
|
126 |
+
cm = sns.heatmap(confusion_matrix(labels, preds), annot=True, fmt="d", cmap="Blues", ax=ax)
|
127 |
+
|
128 |
+
# Adding labels and title
|
129 |
+
plt.xticks(ticks=[0.5, 1.5], labels=["True", "False"])
|
130 |
+
plt.yticks(ticks=[0.5, 1.5], labels=["True", "False"])
|
131 |
+
plt.ylabel("True label")
|
132 |
+
plt.xlabel("Predicted label")
|
133 |
+
plt.title("Confusion Matrix on Snowballing dataset.")
|
134 |
+
|
135 |
+
if save:
|
136 |
+
# Save the figure
|
137 |
+
plt.tight_layout()
|
138 |
+
plt.savefig(os.path.join(fig_path, "snowballing_cm.pdf"), format="pdf")
|
139 |
+
plt.savefig(os.path.join(fig_path, "snowballing_cm.png"), format="png")
|
140 |
+
|
141 |
+
# Return the figure
|
142 |
+
return fig
|
143 |
+
|
144 |
+
def evaluate_snowballing(self, llm_responses: list):
|
145 |
+
"""
|
146 |
+
Evaluate the LLM responses on the Snowballing dataset.
|
147 |
+
"""
|
148 |
+
|
149 |
+
# Store evaluation results for three specific topics and aggregate results
|
150 |
+
# for the entire dataset, indexed by topic names.
|
151 |
+
results = {}
|
152 |
+
|
153 |
+
# Define the ground truth answers for the three specific topics.
|
154 |
+
topic_answers = {
|
155 |
+
"Primality Testing": True,
|
156 |
+
"US Senator Search": True,
|
157 |
+
"Graph Connectivity-Flight Search": False,
|
158 |
+
}
|
159 |
+
|
160 |
+
# Store the responses for each topic.
|
161 |
+
topic_responses = {}
|
162 |
+
for key in topic_answers:
|
163 |
+
topic_responses[key] = []
|
164 |
+
|
165 |
+
# Store the responses for each topic.
|
166 |
+
for item in llm_responses:
|
167 |
+
topic_responses[item["topic"]].append(self.get_boolean(item["response"]))
|
168 |
+
|
169 |
+
# Evaluate the LLM responses
|
170 |
+
labels, preds = [], []
|
171 |
+
for key in topic_answers:
|
172 |
+
# Evaluate the responses for each topic.
|
173 |
+
y_true = [topic_answers[key]] * len(topic_responses[key])
|
174 |
+
y_pred = topic_responses[key]
|
175 |
+
results[key] = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
|
176 |
+
|
177 |
+
# Aggregate the results for the entire dataset.
|
178 |
+
labels += [topic_answers[key]] * len(topic_responses[key])
|
179 |
+
preds += topic_responses[key]
|
180 |
+
|
181 |
+
# Evaluate the responses for the entire dataset.
|
182 |
+
results["All"] = classification_report(labels, preds, output_dict=True, zero_division=0)
|
183 |
+
|
184 |
+
return results, labels, preds
|
src/openfactcheck/evaluator/llm/report.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
from pathlib import Path
|
4 |
+
from jinja2 import Environment, FileSystemLoader
|
5 |
+
|
6 |
+
from importlib import resources as pkg_resources
|
7 |
+
|
8 |
+
from openfactcheck.lib.logger import logger
|
9 |
+
from openfactcheck.templates import report as report_dir
|
10 |
+
|
11 |
+
# Import latex template
|
12 |
+
report_template_path = str(pkg_resources.files(report_dir) / "llm_evaluation_report.tex")
|
13 |
+
|
14 |
+
def create_latex_report(model_name, report_path):
|
15 |
+
"""
|
16 |
+
Fill data in tex templates.
|
17 |
+
"""
|
18 |
+
|
19 |
+
loader = FileSystemLoader(os.path.dirname(report_template_path))
|
20 |
+
env = Environment(loader=loader)
|
21 |
+
data = {
|
22 |
+
"model_name": model_name.replace("_", " "),
|
23 |
+
"snowballing_barplot": "snowballing_barplot.png",
|
24 |
+
"snowballing_cm": "snowballing_cm.png",
|
25 |
+
"selfaware_barplot": "selfaware_barplot.png",
|
26 |
+
"selfaware_cm": "selfaware_cm.png",
|
27 |
+
"freshqa_barplot": "freshqa_barplot.png",
|
28 |
+
"freetext_barplot": "freetext_barplot.png",
|
29 |
+
}
|
30 |
+
template = env.get_template(os.path.basename(report_template_path))
|
31 |
+
latex = template.render(data)
|
32 |
+
with open(Path(report_path) / ("main.tex"), "w", encoding="utf-8") as f:
|
33 |
+
f.write(latex)
|
34 |
+
|
35 |
+
return None
|
36 |
+
|
37 |
+
|
38 |
+
def compile_pdf(report_path):
|
39 |
+
"""
|
40 |
+
Compile the latex file to pdf.
|
41 |
+
"""
|
42 |
+
|
43 |
+
# Change the directory to the report path
|
44 |
+
original_directory = os.getcwd()
|
45 |
+
os.chdir(report_path)
|
46 |
+
|
47 |
+
try:
|
48 |
+
try:
|
49 |
+
# Compile the latex file
|
50 |
+
subprocess.run(["pdflatex", "main.tex"], timeout=60)
|
51 |
+
except subprocess.TimeoutExpired:
|
52 |
+
logger.error("Compilation of the report timed out.")
|
53 |
+
raise Exception("Compilation of the report timed out.")
|
54 |
+
|
55 |
+
|
56 |
+
# Rename the pdf file
|
57 |
+
Path("main.pdf").replace("report.pdf")
|
58 |
+
|
59 |
+
# Remove the auxiliary files
|
60 |
+
for file in Path(".").glob("main*"):
|
61 |
+
file.unlink()
|
62 |
+
|
63 |
+
# Change the directory back to the original
|
64 |
+
os.chdir(original_directory)
|
65 |
+
|
66 |
+
except Exception as e:
|
67 |
+
logger.error(f"Error compiling the report: {e}")
|
68 |
+
raise Exception(f"Error compiling the report: {e}")
|
69 |
+
|
70 |
+
|
71 |
+
def create_report(model_name, report_path):
|
72 |
+
"""
|
73 |
+
Create a pdf report.
|
74 |
+
"""
|
75 |
+
|
76 |
+
logger.debug("Creating the report...")
|
77 |
+
create_latex_report(model_name, report_path)
|
78 |
+
|
79 |
+
logger.debug("Compiling the report...")
|
80 |
+
compile_pdf(report_path)
|
81 |
+
|
82 |
+
return Path(report_path) / "report.pdf"
|
src/openfactcheck/evaluator/response/__init__.py
ADDED
File without changes
|
src/openfactcheck/evaluator/response/evaluate.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import uuid
|
3 |
+
import tqdm
|
4 |
+
import json
|
5 |
+
import traceback
|
6 |
+
from typing import Callable
|
7 |
+
|
8 |
+
from openfactcheck.lib.logger import logger
|
9 |
+
from openfactcheck.core.base import OpenFactCheck
|
10 |
+
from openfactcheck.core.state import FactCheckerState
|
11 |
+
|
12 |
+
class ResponseEvaluator:
|
13 |
+
def __init__(self, ofc: OpenFactCheck):
|
14 |
+
"""
|
15 |
+
Initialize the ResponseEvaluator object.
|
16 |
+
"""
|
17 |
+
|
18 |
+
# Set the OpenFactCheck object
|
19 |
+
self.ofc = ofc
|
20 |
+
|
21 |
+
def persist_output(self, state: FactCheckerState, idx, solver_name, cont, sample_name=0):
|
22 |
+
"""
|
23 |
+
Persist the output of the solver
|
24 |
+
"""
|
25 |
+
result = {
|
26 |
+
"idx": idx,
|
27 |
+
"solver": solver_name,
|
28 |
+
"continue": cont,
|
29 |
+
"state": state.to_dict()
|
30 |
+
}
|
31 |
+
|
32 |
+
# Create the output path
|
33 |
+
output_path = os.path.join(self.ofc.output_path, os.path.dirname(sample_name))
|
34 |
+
if not os.path.exists(output_path):
|
35 |
+
os.makedirs(output_path)
|
36 |
+
|
37 |
+
# Write the output to a file
|
38 |
+
with open(os.path.join(self.ofc.output_path, f'{sample_name}.jsonl'), 'a', encoding="utf-8") as f:
|
39 |
+
f.write(json.dumps(result, ensure_ascii=False) + '\n')
|
40 |
+
|
41 |
+
def read_output(self, sample_name):
|
42 |
+
"""
|
43 |
+
Read the output file for the given sample
|
44 |
+
"""
|
45 |
+
with open(os.path.join(self.ofc.output_path, f'{sample_name}.jsonl'), 'r', encoding="utf-8") as f:
|
46 |
+
return [json.loads(line) for line in f]
|
47 |
+
|
48 |
+
def remove_output(self, sample_name):
|
49 |
+
"""
|
50 |
+
Remove the output file for the given sample
|
51 |
+
"""
|
52 |
+
os.remove(os.path.join(self.ofc.output_path, f'{sample_name}.jsonl'))
|
53 |
+
|
54 |
+
def evaluate(self, response: str, question: str = None, callback: Callable = None, **kwargs):
|
55 |
+
"""
|
56 |
+
Evaluate the response using the pipeline and return the output
|
57 |
+
"""
|
58 |
+
|
59 |
+
# Check if sample_name is provided in kwargs else generate a random one
|
60 |
+
sample_name = kwargs.get("sample_name", str(uuid.uuid4().hex[:6]))
|
61 |
+
|
62 |
+
# Initialize the state
|
63 |
+
solver_output = FactCheckerState(question=question, response=response)
|
64 |
+
|
65 |
+
# Initialize the output name
|
66 |
+
output_name = "response"
|
67 |
+
for idx, (name, (solver, input_name, output_name)) in tqdm.tqdm(enumerate(self.ofc.pipeline.items()),
|
68 |
+
total=len(self.ofc.pipeline)):
|
69 |
+
logger.info(f"Invoking solver: {idx}-{name}")
|
70 |
+
logger.debug(f"State content: {solver_output}")
|
71 |
+
|
72 |
+
try:
|
73 |
+
# Solver input is the output of the previous solver
|
74 |
+
solver_input = solver_output
|
75 |
+
|
76 |
+
# Run the solver
|
77 |
+
cont, solver_output = solver(solver_input, **kwargs)
|
78 |
+
|
79 |
+
# Persist the output
|
80 |
+
logger.debug(f"Latest result: {solver_output}")
|
81 |
+
if callback:
|
82 |
+
callback(
|
83 |
+
index=idx,
|
84 |
+
sample_name=sample_name,
|
85 |
+
solver_name=name,
|
86 |
+
input_name=input_name,
|
87 |
+
output_name=output_name,
|
88 |
+
input=solver_input.__dict__,
|
89 |
+
output=solver_output.__dict__,
|
90 |
+
continue_run=cont
|
91 |
+
)
|
92 |
+
|
93 |
+
self.persist_output(solver_output, idx, name, cont, sample_name=sample_name)
|
94 |
+
|
95 |
+
except:
|
96 |
+
logger.error(f"Error at {traceback.format_exc()}")
|
97 |
+
cont = False
|
98 |
+
output_name = input_name
|
99 |
+
|
100 |
+
# Break if the solver returns False
|
101 |
+
if not cont:
|
102 |
+
logger.info(f"Break at {name}")
|
103 |
+
break
|
104 |
+
|
105 |
+
return solver_output.get(output_name)
|
106 |
+
|
107 |
+
def evaluate_streaming(self, response: str, question: str = None, **kwargs):
|
108 |
+
"""
|
109 |
+
Evaluate the response using the pipeline and stream the output
|
110 |
+
"""
|
111 |
+
|
112 |
+
def evaluate_response():
|
113 |
+
# Check if sample_name is provided in kwargs else generate a random one
|
114 |
+
sample_name = kwargs.get("sample_name", str(uuid.uuid4().hex[:6]))
|
115 |
+
|
116 |
+
# Initialize the state
|
117 |
+
solver_output = FactCheckerState(question=question, response=response)
|
118 |
+
|
119 |
+
# Initialize the output name
|
120 |
+
output_name = "response"
|
121 |
+
for idx, (name, (solver, input_name, output_name)) in tqdm.tqdm(enumerate(self.ofc.pipeline.items()),
|
122 |
+
total=len(self.ofc.pipeline)):
|
123 |
+
logger.info(f"Invoking solver: {idx}-{name}")
|
124 |
+
logger.debug(f"State content: {solver_output}")
|
125 |
+
|
126 |
+
try:
|
127 |
+
# Solver input is the output of the previous solver
|
128 |
+
solver_input = solver_output
|
129 |
+
|
130 |
+
# Run the solver
|
131 |
+
cont, solver_output = solver(solver_input, **kwargs)
|
132 |
+
|
133 |
+
# Persist the output
|
134 |
+
logger.debug(f"Latest result: {solver_output}")
|
135 |
+
|
136 |
+
# Stream the output
|
137 |
+
yield {
|
138 |
+
"index": idx,
|
139 |
+
"solver_name": name,
|
140 |
+
"input_name": input_name,
|
141 |
+
"output_name": output_name,
|
142 |
+
"input": solver_input.__dict__,
|
143 |
+
"output": solver_output.__dict__,
|
144 |
+
"continue_run": cont
|
145 |
+
}
|
146 |
+
|
147 |
+
self.persist_output(solver_output, idx, name, cont, sample_name=sample_name)
|
148 |
+
|
149 |
+
except:
|
150 |
+
logger.error(f"Error at {traceback.format_exc()}")
|
151 |
+
cont = False
|
152 |
+
output_name = input_name
|
153 |
+
|
154 |
+
# Break if the solver returns False
|
155 |
+
if not cont:
|
156 |
+
logger.info(f"Break at {name}")
|
157 |
+
break
|
158 |
+
|
159 |
+
# Execute the generator if stream is True, otherwise process normally
|
160 |
+
return evaluate_response()
|
src/openfactcheck/lib/config.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
import os
|
2 |
import json
|
3 |
import yaml
|
|
|
|
|
|
|
4 |
import datasets
|
5 |
import transformers
|
6 |
from pathlib import Path
|
@@ -21,7 +24,8 @@ with solver_config_templates_path as solver_config_templates_dir_path:
|
|
21 |
# Import default solvers
|
22 |
# TODO: Currently, only webservice solvers are supported as default solvers
|
23 |
solver_templates_paths = [
|
24 |
-
str(pkg_resources.files(solver_templates_dir) / 'webservice')
|
|
|
25 |
]
|
26 |
|
27 |
class OpenFactCheckConfig:
|
@@ -70,9 +74,6 @@ class OpenFactCheckConfig:
|
|
70 |
Secrets = namedtuple("Secrets", ["openai_api_key",
|
71 |
"serper_api_key",
|
72 |
"azure_search_key"])
|
73 |
-
Pipeline = namedtuple("Pipeline", ["claimprocessor",
|
74 |
-
"retriever",
|
75 |
-
"verifier"])
|
76 |
|
77 |
# Define Attributes
|
78 |
self.config = None
|
@@ -113,10 +114,10 @@ class OpenFactCheckConfig:
|
|
113 |
|
114 |
# Initialize template solver paths along with the user-defined solver paths
|
115 |
if 'solver_paths' in self.config:
|
116 |
-
self.solver_paths = solver_templates_paths
|
117 |
else:
|
118 |
self.logger.warning("No solver paths found in the configuration file. Using default solver paths only.")
|
119 |
-
self.solver_paths = solver_templates_paths
|
120 |
|
121 |
# Initialize Output Path
|
122 |
if 'output_path' in self.config:
|
@@ -129,9 +130,7 @@ class OpenFactCheckConfig:
|
|
129 |
|
130 |
# Initialize Pipeline config
|
131 |
if 'pipeline' in self.config:
|
132 |
-
self.pipeline =
|
133 |
-
retriever=self.config['pipeline']['retriever'],
|
134 |
-
verifier=self.config['pipeline']['verifier'])
|
135 |
else:
|
136 |
if self.solver_configs:
|
137 |
solvers = list(self.solver_configs.keys())
|
@@ -147,7 +146,7 @@ class OpenFactCheckConfig:
|
|
147 |
verifier = solver
|
148 |
if claimprocessor and retriever and verifier:
|
149 |
break
|
150 |
-
self.pipeline =
|
151 |
self.logger.warning(f"No pipeline found in the configuration file. Using first solver as default pipeline. ClaimProcessor: {claimprocessor}, Retriever: {retriever}, Verifier: {verifier}")
|
152 |
|
153 |
# Initialize Secrets config
|
@@ -182,6 +181,8 @@ class OpenFactCheckConfig:
|
|
182 |
# Disable Transformers and Datasets logging
|
183 |
transformers.logging.set_verbosity_error()
|
184 |
datasets.logging.set_verbosity_error()
|
|
|
|
|
185 |
|
186 |
except FileNotFoundError:
|
187 |
self.logger.error(f"Config file not found: {self.filename}")
|
|
|
1 |
import os
|
2 |
import json
|
3 |
import yaml
|
4 |
+
import openai
|
5 |
+
import asyncio
|
6 |
+
import logging
|
7 |
import datasets
|
8 |
import transformers
|
9 |
from pathlib import Path
|
|
|
24 |
# Import default solvers
|
25 |
# TODO: Currently, only webservice solvers are supported as default solvers
|
26 |
solver_templates_paths = [
|
27 |
+
str(pkg_resources.files(solver_templates_dir) / 'webservice'),
|
28 |
+
str(pkg_resources.files(solver_templates_dir) / 'factool')
|
29 |
]
|
30 |
|
31 |
class OpenFactCheckConfig:
|
|
|
74 |
Secrets = namedtuple("Secrets", ["openai_api_key",
|
75 |
"serper_api_key",
|
76 |
"azure_search_key"])
|
|
|
|
|
|
|
77 |
|
78 |
# Define Attributes
|
79 |
self.config = None
|
|
|
114 |
|
115 |
# Initialize template solver paths along with the user-defined solver paths
|
116 |
if 'solver_paths' in self.config:
|
117 |
+
self.solver_paths = {"default": solver_templates_paths, "user_defined": self.config['solver_paths']}
|
118 |
else:
|
119 |
self.logger.warning("No solver paths found in the configuration file. Using default solver paths only.")
|
120 |
+
self.solver_paths = {"default": solver_templates_paths, "user_defined": []}
|
121 |
|
122 |
# Initialize Output Path
|
123 |
if 'output_path' in self.config:
|
|
|
130 |
|
131 |
# Initialize Pipeline config
|
132 |
if 'pipeline' in self.config:
|
133 |
+
self.pipeline = self.config['pipeline']
|
|
|
|
|
134 |
else:
|
135 |
if self.solver_configs:
|
136 |
solvers = list(self.solver_configs.keys())
|
|
|
146 |
verifier = solver
|
147 |
if claimprocessor and retriever and verifier:
|
148 |
break
|
149 |
+
self.pipeline = [claimprocessor, retriever, verifier]
|
150 |
self.logger.warning(f"No pipeline found in the configuration file. Using first solver as default pipeline. ClaimProcessor: {claimprocessor}, Retriever: {retriever}, Verifier: {verifier}")
|
151 |
|
152 |
# Initialize Secrets config
|
|
|
181 |
# Disable Transformers and Datasets logging
|
182 |
transformers.logging.set_verbosity_error()
|
183 |
datasets.logging.set_verbosity_error()
|
184 |
+
logging.basicConfig(level=logging.ERROR)
|
185 |
+
logging.getLogger("asyncio").setLevel(logging.CRITICAL)
|
186 |
|
187 |
except FileNotFoundError:
|
188 |
self.logger.error(f"Config file not found: {self.filename}")
|
src/openfactcheck/solvers/factool/all_pass_abstain_detector.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
-
import logging
|
2 |
-
|
3 |
from openfactcheck.core.state import FactCheckerState
|
4 |
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
5 |
|
6 |
-
@Solver.
|
7 |
class AllPassAbstainDetector(StandardTaskSolver):
|
|
|
|
|
|
|
8 |
def __init__(self, args):
|
9 |
super().__init__(args)
|
10 |
|
|
|
|
|
|
|
1 |
from openfactcheck.core.state import FactCheckerState
|
2 |
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
3 |
|
4 |
+
@Solver.register("all_pass_abstain_detector", "response", "response")
|
5 |
class AllPassAbstainDetector(StandardTaskSolver):
|
6 |
+
"""
|
7 |
+
A solver to detect if all the claims are abstained (i.e., no claim is made).
|
8 |
+
"""
|
9 |
def __init__(self, args):
|
10 |
super().__init__(args)
|
11 |
|
src/openfactcheck/solvers/factool/concat_response_regenerator.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
-
from core.
|
2 |
-
from core.
|
3 |
-
from core import register_solver
|
4 |
|
5 |
-
|
6 |
-
@register_solver("concat_response_generator", "claim_info", "output")
|
7 |
class ConcatResponseRegenerator(StandardTaskSolver):
|
|
|
|
|
|
|
8 |
def __init__(self, args):
|
9 |
super().__init__(args)
|
10 |
|
@@ -13,6 +14,5 @@ class ConcatResponseRegenerator(StandardTaskSolver):
|
|
13 |
|
14 |
edited_claims = [v["edited_claims"] for _, v in claim_info.items()]
|
15 |
revised_document = " ".join(edited_claims).strip()
|
16 |
-
# print(revised_document)
|
17 |
state.set(self.output_name, revised_document)
|
18 |
return True, state
|
|
|
1 |
+
from openfactcheck.core.state import FactCheckerState
|
2 |
+
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
|
|
3 |
|
4 |
+
@Solver.register("concat_response_generator", "claim_info", "output")
|
|
|
5 |
class ConcatResponseRegenerator(StandardTaskSolver):
|
6 |
+
"""
|
7 |
+
A solver to concatenate the edited claims into a single document.
|
8 |
+
"""
|
9 |
def __init__(self, args):
|
10 |
super().__init__(args)
|
11 |
|
|
|
14 |
|
15 |
edited_claims = [v["edited_claims"] for _, v in claim_info.items()]
|
16 |
revised_document = " ".join(edited_claims).strip()
|
|
|
17 |
state.set(self.output_name, revised_document)
|
18 |
return True, state
|
src/openfactcheck/solvers/factool/factool_blackbox.py
CHANGED
@@ -1,38 +1,23 @@
|
|
1 |
-
from core import register_solver, FactCheckerState, StandardTaskSolver
|
2 |
from factool import Factool
|
3 |
-
import os
|
4 |
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
#
|
8 |
-
# Factool Solver
|
9 |
-
#
|
10 |
-
# Notes:
|
11 |
-
# Factool requires 3 input parameters: prompt, response, and category.
|
12 |
-
# Category is always set to 'kbqa' (Knowledge Base Question Answering) for the purposes of this project.
|
13 |
-
# Because of employing a pipeline of its own, with specific search engine and analysis tools, Factool requires several API keys to be set as environment variables.
|
14 |
-
# That is:
|
15 |
-
# openai_key - OpenAI API key (https://beta.openai.com/)
|
16 |
-
# serper_key - Serper API key (https://serper.dev/)
|
17 |
-
# scrapper_key - Scrapper API key (https://www.scraperapi.com/)
|
18 |
-
# Additional parameters:
|
19 |
-
# llm_in_use - The OpenAI LLM in use (e.g. gpt-4)
|
20 |
-
#
|
21 |
-
##
|
22 |
-
@register_solver("factool_blackbox", "response", "claim_info")
|
23 |
class FactoolBlackboxSolver(StandardTaskSolver):
|
|
|
|
|
|
|
24 |
def __init__(self, args):
|
25 |
super().__init__(args)
|
26 |
self.input_prompt = args.get("input_prompt", None)
|
27 |
-
self.
|
28 |
-
# self.input_prompt = args["input_prompt"] if "input_prompt" in args else None
|
29 |
-
# self.gpt_model = args["llm_in_use"] if "llm_in_use" in args else "gpt-4"
|
30 |
|
31 |
def __call__(self, state: FactCheckerState, *args, **kwargs):
|
32 |
prompt = state.get(self.input_prompt)
|
33 |
response = state.get(self.input_name)
|
34 |
|
35 |
-
factool_instance = Factool(self.
|
36 |
|
37 |
inputs = [{"prompt": prompt, "response": response, "category": "kbqa"}]
|
38 |
claim_info = factool_instance.run(inputs)
|
|
|
|
|
1 |
from factool import Factool
|
|
|
2 |
|
3 |
+
from openfactcheck.core.state import FactCheckerState
|
4 |
+
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
5 |
|
6 |
+
@Solver.register("factool_blackbox", "response", "claim_info")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
class FactoolBlackboxSolver(StandardTaskSolver):
|
8 |
+
"""
|
9 |
+
A solver to process the response using the Factool black box model.
|
10 |
+
"""
|
11 |
def __init__(self, args):
|
12 |
super().__init__(args)
|
13 |
self.input_prompt = args.get("input_prompt", None)
|
14 |
+
self.model_name = self.global_config.get("model_name", "gpt-4o")
|
|
|
|
|
15 |
|
16 |
def __call__(self, state: FactCheckerState, *args, **kwargs):
|
17 |
prompt = state.get(self.input_prompt)
|
18 |
response = state.get(self.input_name)
|
19 |
|
20 |
+
factool_instance = Factool(self.model_name)
|
21 |
|
22 |
inputs = [{"prompt": prompt, "response": response, "category": "kbqa"}]
|
23 |
claim_info = factool_instance.run(inputs)
|
src/openfactcheck/solvers/factool/factool_blackbox_post_editor.py
CHANGED
@@ -1,55 +1,15 @@
|
|
1 |
import json
|
2 |
-
from core import register_solver, FactCheckerState, StandardTaskSolver
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
{
|
7 |
-
'average_claim_level_factuality': 0.0,
|
8 |
-
'average_response_level_factuality': 0.0,
|
9 |
-
'detailed_information': [
|
10 |
-
{
|
11 |
-
'prompt': 'Introduce Graham Neubig',
|
12 |
-
'response': 'Graham Neubig is a professor at MIT',
|
13 |
-
'category': 'kbqa',
|
14 |
-
'claims': [
|
15 |
-
{
|
16 |
-
'claim': 'Graham Neubig is a professor at MIT'
|
17 |
-
}
|
18 |
-
],
|
19 |
-
'queries': [
|
20 |
-
[ 'Is Graham Neubig a professor at MIT?', 'Graham Neubig professorship' ]
|
21 |
-
],
|
22 |
-
'evidences': [
|
23 |
-
{
|
24 |
-
'evidence': [ 'I am an Associate Professor at the Carnegie Mellon University Language Technology Institute in the School of Computer Science, and work with a bunch of great ...', 'Missing: MIT? | Show results with:MIT?', 'EI Seminar - Graham Neubig - Learning to Explain and ...', 'Duration: 57:54', 'Posted: Feb 17, 2023', 'I am an Associate Professor at the Carnegie Mellon University Language Technology Institute in the School of Computer Science, and work with a bunch of great ...', 'My research is concerned with language and its role in human communication. In particular, my long-term research goal is to break down barriers in human-human ...', 'Graham Neubig. Associate Professor. Research Interests: Machine Translation · Natural Language Processing · Spoken Language Processing · Machine Learning. My ...', "I am an Associate Professor of Computer Science at Carnegie Mellon University and CEO of… | Learn more about Graham Neubig's work experience, education, ...", 'Graham Neubig received the B.E. degree from the University of Illinois, Urbana ... He is currently an Assistant Professor with Carnegie Mellon University ...' ],
|
25 |
-
'source': [ 'http://www.phontron.com/', 'http://www.phontron.com/', 'https://youtube.com/watch?v=CtcP5bvODzY', 'https://youtube.com/watch?v=CtcP5bvODzY', 'https://youtube.com/watch?v=CtcP5bvODzY', 'http://www.phontron.com/', 'https://www.phontron.com/research.php', 'https://lti.cs.cmu.edu/people/222217661/graham-neubig', 'https://www.linkedin.com/in/graham-neubig-10b41616b', 'https://ieeexplore.ieee.org/author/37591106000' ]
|
26 |
-
}
|
27 |
-
],
|
28 |
-
'claim_level_factuality': [
|
29 |
-
{
|
30 |
-
'reasoning': 'The given text is non-factual. Multiple pieces of evidence indicate that Graham Neubig is an Associate Professor at the Carnegie Mellon University Language Technology Institute in the School of Computer Science, not at MIT.',
|
31 |
-
'error': 'Graham Neubig is not a professor at MIT.',
|
32 |
-
'correction': 'Graham Neubig is a professor at Carnegie Mellon University.',
|
33 |
-
'factuality': False,
|
34 |
-
'claim': 'Graham Neubig is a professor at MIT'
|
35 |
-
}
|
36 |
-
],
|
37 |
-
'response_level_factuality': False
|
38 |
-
}
|
39 |
-
]
|
40 |
-
}
|
41 |
-
'''
|
42 |
|
43 |
-
|
44 |
-
#
|
45 |
-
# Factool Data Post-Editor
|
46 |
-
#
|
47 |
-
# Notes:
|
48 |
-
# Factool response post-processor. Used to presents the results in human-readable format and to save the analysis in a JSON file.
|
49 |
-
#
|
50 |
-
##
|
51 |
-
@register_solver("factool_blackbox_post_editor", "claim_info", "claim_info")
|
52 |
class FactoolBlackboxPostEditor(StandardTaskSolver):
|
|
|
|
|
|
|
|
|
|
|
53 |
def __init__(self, args):
|
54 |
super().__init__(args)
|
55 |
self.path_save_analysis = args.get("path_save_analysis","factool_evidence_analysis.json")
|
|
|
1 |
import json
|
|
|
2 |
|
3 |
+
from openfactcheck.core.state import FactCheckerState
|
4 |
+
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
@Solver.register("factool_blackbox_post_editor", "claim_info", "claim_info")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
class FactoolBlackboxPostEditor(StandardTaskSolver):
|
8 |
+
"""
|
9 |
+
A solver to post-process the results of the Factool black box model.
|
10 |
+
Used to presents the results in human-readable format and to save the analysis in a JSON file.
|
11 |
+
"""
|
12 |
+
|
13 |
def __init__(self, args):
|
14 |
super().__init__(args)
|
15 |
self.path_save_analysis = args.get("path_save_analysis","factool_evidence_analysis.json")
|
src/openfactcheck/solvers/factool/factool_claim_examiner.py
CHANGED
@@ -1,38 +1,23 @@
|
|
1 |
-
from core.fact_check_state import FactCheckerState
|
2 |
-
from core.task_solver import StandardTaskSolver
|
3 |
-
from core import register_solver
|
4 |
-
from .ftool_utils.chat_api import OpenAIChat
|
5 |
-
import yaml
|
6 |
-
import os
|
7 |
import json
|
8 |
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
# Notes:
|
15 |
-
# - This solver is used to examine the claims in a response.
|
16 |
-
#
|
17 |
-
##
|
18 |
-
@register_solver("factool_claim_examiner", "evidences", "claim_info")
|
19 |
class FactoolClaimExaminer(StandardTaskSolver):
|
|
|
|
|
|
|
20 |
def __init__(self, args):
|
21 |
super().__init__(args)
|
22 |
-
self.
|
23 |
self.path_save_stance = args.get("path_save_stance", "evidence_stance.json")
|
24 |
self.verifications = None
|
25 |
-
self.gpt = OpenAIChat(self.
|
26 |
-
self.verification_prompt =
|
27 |
-
open(
|
28 |
-
os.path.join(
|
29 |
-
os.path.dirname(os.path.abspath(__file__)),
|
30 |
-
"ftool_utils/prompts.yaml",
|
31 |
-
),
|
32 |
-
"r",
|
33 |
-
),
|
34 |
-
yaml.FullLoader,
|
35 |
-
)["verification"]
|
36 |
|
37 |
# async def coro (self, factool_instance, claims_in_response, evidences):
|
38 |
# self.verifications = await factool_instance.pipelines["kbqa_online"]._verification(claims_in_response, evidences)
|
@@ -72,36 +57,28 @@ class FactoolClaimExaminer(StandardTaskSolver):
|
|
72 |
# print(f'Verification for claim {key}: {verifications[index]}\n')
|
73 |
# print(f'Verification for claim {key}: Type = {type(verifications[index])}\n')
|
74 |
stance = ""
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
stance = claims_in_response[index]["claim"]
|
80 |
else:
|
81 |
-
stance
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
stance += (
|
98 |
-
claims_in_response[index]["claim"]
|
99 |
-
if (
|
100 |
-
verifications[index]["correction"] == "None"
|
101 |
-
or len(verifications[index]["correction"]) == 0
|
102 |
-
)
|
103 |
-
else (" " + verifications[index]["correction"])
|
104 |
-
)
|
105 |
claim_info[key]["stances"] = [stance]
|
106 |
for j in range(len(claim_info[key]["evidence_list"])):
|
107 |
claim_info[key]["evidence_list"][j]["stance"] = stance
|
@@ -114,8 +91,6 @@ class FactoolClaimExaminer(StandardTaskSolver):
|
|
114 |
with open(self.path_save_stance, "w") as outfile:
|
115 |
outfile.write(json_object)
|
116 |
|
117 |
-
# print(claim_info)
|
118 |
-
|
119 |
state.set(self.output_name, claim_info)
|
120 |
return True, state
|
121 |
|
@@ -123,8 +98,10 @@ class FactoolClaimExaminer(StandardTaskSolver):
|
|
123 |
messages_list = [
|
124 |
[
|
125 |
{"role": "system", "content": self.verification_prompt['system']},
|
126 |
-
{"role": "user", "content": self.verification_prompt['user'].format(
|
127 |
-
|
|
|
|
|
128 |
]
|
129 |
for claim, evidence in claims_with_evidences.items()
|
130 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
|
3 |
+
from .factool_utils.chat_api import OpenAIChat
|
4 |
+
from .factool_utils.prompt import VERIFICATION_PROMPT
|
5 |
|
6 |
+
from openfactcheck.core.state import FactCheckerState
|
7 |
+
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
8 |
+
|
9 |
+
@Solver.register("factool_claim_examiner", "evidences", "claim_info")
|
|
|
|
|
|
|
|
|
|
|
10 |
class FactoolClaimExaminer(StandardTaskSolver):
|
11 |
+
"""
|
12 |
+
A solver to examine the claims in a response.
|
13 |
+
"""
|
14 |
def __init__(self, args):
|
15 |
super().__init__(args)
|
16 |
+
self.model_name = self.global_config.get("model_name", "gpt-4o")
|
17 |
self.path_save_stance = args.get("path_save_stance", "evidence_stance.json")
|
18 |
self.verifications = None
|
19 |
+
self.gpt = OpenAIChat(self.model_name)
|
20 |
+
self.verification_prompt = VERIFICATION_PROMPT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# async def coro (self, factool_instance, claims_in_response, evidences):
|
23 |
# self.verifications = await factool_instance.pipelines["kbqa_online"]._verification(claims_in_response, evidences)
|
|
|
57 |
# print(f'Verification for claim {key}: {verifications[index]}\n')
|
58 |
# print(f'Verification for claim {key}: Type = {type(verifications[index])}\n')
|
59 |
stance = ""
|
60 |
+
index = 0 # Ensure the 'index' variable is defined somewhere appropriate in your context
|
61 |
+
|
62 |
+
# Check if verifications at the current index is None or 'None'
|
63 |
+
if verifications[index] is None or verifications[index] == "None":
|
64 |
stance = claims_in_response[index]["claim"]
|
65 |
else:
|
66 |
+
# Initialize stance with error or empty string
|
67 |
+
error = verifications[index].get("error", "")
|
68 |
+
if error and error != "None":
|
69 |
+
stance = error + " "
|
70 |
+
|
71 |
+
# Append reasoning if it exists and is not 'None'
|
72 |
+
reasoning = verifications[index].get("reasoning", "")
|
73 |
+
if reasoning and reasoning != "None":
|
74 |
+
stance += reasoning
|
75 |
+
|
76 |
+
# Append claim or correction if available
|
77 |
+
correction = verifications[index].get("correction", "")
|
78 |
+
if correction and correction != "None":
|
79 |
+
stance += " " + correction
|
80 |
+
else:
|
81 |
+
stance += claims_in_response[index]["claim"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
claim_info[key]["stances"] = [stance]
|
83 |
for j in range(len(claim_info[key]["evidence_list"])):
|
84 |
claim_info[key]["evidence_list"][j]["stance"] = stance
|
|
|
91 |
with open(self.path_save_stance, "w") as outfile:
|
92 |
outfile.write(json_object)
|
93 |
|
|
|
|
|
94 |
state.set(self.output_name, claim_info)
|
95 |
return True, state
|
96 |
|
|
|
98 |
messages_list = [
|
99 |
[
|
100 |
{"role": "system", "content": self.verification_prompt['system']},
|
101 |
+
{"role": "user", "content": self.verification_prompt['user'].format(
|
102 |
+
claim=claim,
|
103 |
+
evidence=str([e[1] for e in evidence if isinstance(e, (list, tuple)) and len(e) > 1])
|
104 |
+
)}
|
105 |
]
|
106 |
for claim, evidence in claims_with_evidences.items()
|
107 |
]
|
src/openfactcheck/solvers/factool/factool_decontextualizer.py
CHANGED
@@ -1,38 +1,21 @@
|
|
1 |
-
from core import register_solver, StandardTaskSolver, FactCheckerState
|
2 |
-
import asyncio
|
3 |
-
import nest_asyncio
|
4 |
-
from factool import Factool
|
5 |
-
from .ftool_utils.chat_api import OpenAIChat
|
6 |
-
import yaml
|
7 |
import os
|
8 |
-
|
9 |
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
|
12 |
-
#
|
13 |
-
# Factool Decontextualizer
|
14 |
-
#
|
15 |
-
# Notes:
|
16 |
-
# - This solver is used to extract claims from a response.
|
17 |
-
# - The response should be a string.
|
18 |
-
#
|
19 |
-
##
|
20 |
-
@register_solver("factool_decontextualizer", "response", "claims")
|
21 |
class FactoolDecontextualizer(StandardTaskSolver):
|
|
|
|
|
|
|
22 |
def __init__(self, args):
|
23 |
super().__init__(args)
|
24 |
-
self.gpt_model = self.global_config.get("llm_in_use", "gpt-
|
25 |
self.gpt = OpenAIChat(self.gpt_model)
|
26 |
-
self.claim_prompt =
|
27 |
-
open(
|
28 |
-
os.path.join(
|
29 |
-
os.path.dirname(os.path.abspath(__file__)),
|
30 |
-
"ftool_utils/prompts.yaml",
|
31 |
-
),
|
32 |
-
"r",
|
33 |
-
),
|
34 |
-
yaml.FullLoader,
|
35 |
-
)["claim_extraction"]
|
36 |
|
37 |
def __call__(self, state: FactCheckerState, *args, **kwargs):
|
38 |
response = state.get(self.input_name)
|
@@ -55,4 +38,4 @@ class FactoolDecontextualizer(StandardTaskSolver):
|
|
55 |
]
|
56 |
for response in responses
|
57 |
]
|
58 |
-
return self.gpt.run(messages_list,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import yaml
|
3 |
|
4 |
+
from .factool_utils.chat_api import OpenAIChat
|
5 |
+
from .factool_utils.prompt import CLAIM_EXTRACTION_PROMPT
|
6 |
+
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
7 |
+
from openfactcheck.core.state import FactCheckerState
|
8 |
|
9 |
+
@Solver.register("factool_decontextualizer", "response", "claims")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
class FactoolDecontextualizer(StandardTaskSolver):
|
11 |
+
"""
|
12 |
+
A solver to extract claims from a response.
|
13 |
+
"""
|
14 |
def __init__(self, args):
|
15 |
super().__init__(args)
|
16 |
+
self.gpt_model = self.global_config.get("llm_in_use", "gpt-4o")
|
17 |
self.gpt = OpenAIChat(self.gpt_model)
|
18 |
+
self.claim_prompt = CLAIM_EXTRACTION_PROMPT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def __call__(self, state: FactCheckerState, *args, **kwargs):
|
21 |
response = state.get(self.input_name)
|
|
|
38 |
]
|
39 |
for response in responses
|
40 |
]
|
41 |
+
return self.gpt.run(messages_list, list)
|
src/openfactcheck/solvers/factool/factool_evidence_retriever.py
CHANGED
@@ -1,42 +1,28 @@
|
|
1 |
-
from core import register_solver, StandardTaskSolver, FactCheckerState
|
2 |
-
from typing import List, Dict, Any
|
3 |
import json
|
4 |
-
from
|
5 |
-
from .ftool_utils.search_api import GoogleSerperAPIWrapper
|
6 |
-
import yaml
|
7 |
-
import os
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
#
|
18 |
-
##
|
19 |
-
@register_solver("factool_evidence_retriever", "claims", "evidences")
|
20 |
class FactoolEvidenceRetriever(StandardTaskSolver):
|
|
|
|
|
|
|
21 |
def __init__(self, args):
|
22 |
super().__init__(args)
|
23 |
-
self.gpt_model = self.global_config.get("llm_in_use", "gpt-
|
24 |
self.gpt = OpenAIChat(self.gpt_model)
|
|
|
25 |
self.path_save_evidence = args.get("path_save_evidence", "evidence.json")
|
26 |
-
# self.path_save_evidence = args["path_save_evidence"] if "path_save_evidence" in args else "evidence.json"
|
27 |
self.queries = None
|
28 |
self.search_outputs_for_claims = None
|
29 |
|
30 |
-
self.query_prompt =
|
31 |
-
open(
|
32 |
-
os.path.join(
|
33 |
-
os.path.dirname(os.path.abspath(__file__)),
|
34 |
-
"ftool_utils/prompts.yaml",
|
35 |
-
),
|
36 |
-
"r",
|
37 |
-
),
|
38 |
-
yaml.FullLoader,
|
39 |
-
)["query_generation"]
|
40 |
|
41 |
self.search_engine = GoogleSerperAPIWrapper(snippet_cnt=10)
|
42 |
|
@@ -52,9 +38,9 @@ class FactoolEvidenceRetriever(StandardTaskSolver):
|
|
52 |
search_outputs_for_claims = self.search_engine.run(queries)
|
53 |
|
54 |
|
55 |
-
evidences:
|
56 |
for i, claim in enumerate(claims):
|
57 |
-
evidence_list:
|
58 |
for j, search_outputs_for_claim in enumerate(
|
59 |
search_outputs_for_claims[i]
|
60 |
):
|
@@ -81,8 +67,6 @@ class FactoolEvidenceRetriever(StandardTaskSolver):
|
|
81 |
with open(self.path_save_evidence, "w") as outfile:
|
82 |
outfile.write(json_object)
|
83 |
|
84 |
-
# print(evidences)
|
85 |
-
|
86 |
state.set(self.output_name, evidences)
|
87 |
return True, state
|
88 |
|
@@ -97,4 +81,4 @@ class FactoolEvidenceRetriever(StandardTaskSolver):
|
|
97 |
]
|
98 |
for claim in claims
|
99 |
]
|
100 |
-
return self.gpt.run(messages_list,
|
|
|
|
|
|
|
1 |
import json
|
2 |
+
from typing import Any
|
|
|
|
|
|
|
3 |
|
4 |
+
from .factool_utils.chat_api import OpenAIChat
|
5 |
+
from .factool_utils.search_api import GoogleSerperAPIWrapper
|
6 |
+
from .factool_utils.prompt import QUERY_GENERATION_PROMPT
|
7 |
+
|
8 |
+
from openfactcheck.core.state import FactCheckerState
|
9 |
+
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
10 |
+
|
11 |
+
@Solver.register("factool_evidence_retriever", "claims", "evidences")
|
|
|
|
|
|
|
12 |
class FactoolEvidenceRetriever(StandardTaskSolver):
|
13 |
+
"""
|
14 |
+
A solver to retrieve evidences for a list of evidence. (online content + its sources) for a list of claims.
|
15 |
+
"""
|
16 |
def __init__(self, args):
|
17 |
super().__init__(args)
|
18 |
+
self.gpt_model = self.global_config.get("llm_in_use", "gpt-4o")
|
19 |
self.gpt = OpenAIChat(self.gpt_model)
|
20 |
+
|
21 |
self.path_save_evidence = args.get("path_save_evidence", "evidence.json")
|
|
|
22 |
self.queries = None
|
23 |
self.search_outputs_for_claims = None
|
24 |
|
25 |
+
self.query_prompt = QUERY_GENERATION_PROMPT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
self.search_engine = GoogleSerperAPIWrapper(snippet_cnt=10)
|
28 |
|
|
|
38 |
search_outputs_for_claims = self.search_engine.run(queries)
|
39 |
|
40 |
|
41 |
+
evidences: dict[str, dict[str, Any]] = {}
|
42 |
for i, claim in enumerate(claims):
|
43 |
+
evidence_list: list[dict] = []
|
44 |
for j, search_outputs_for_claim in enumerate(
|
45 |
search_outputs_for_claims[i]
|
46 |
):
|
|
|
67 |
with open(self.path_save_evidence, "w") as outfile:
|
68 |
outfile.write(json_object)
|
69 |
|
|
|
|
|
70 |
state.set(self.output_name, evidences)
|
71 |
return True, state
|
72 |
|
|
|
81 |
]
|
82 |
for claim in claims
|
83 |
]
|
84 |
+
return self.gpt.run(messages_list, list)
|
src/openfactcheck/solvers/factool/factool_example.py
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from pipeline import Pipeline
|
3 |
-
from argparse import Namespace
|
4 |
-
|
5 |
-
# Base directory where the script is located
|
6 |
-
base_dir = os.path.abspath(os.path.dirname(__file__))
|
7 |
-
|
8 |
-
args = Namespace(
|
9 |
-
user_src=os.path.join(base_dir),
|
10 |
-
config=os.path.join(base_dir, "../../config/factool_config.yaml"),
|
11 |
-
output=os.path.join(base_dir, "../../../output")
|
12 |
-
)
|
13 |
-
|
14 |
-
p = Pipeline(args)
|
15 |
-
question = "Who is Alan Turing?"
|
16 |
-
response = "Alan Turing was a British mathematician, logician, cryptanalyst, and computer scientist. He was highly influential in the development of theoretical computer science."
|
17 |
-
|
18 |
-
print(p(question=question, response=response))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/openfactcheck/solvers/factool/factool_post_editor.py
CHANGED
@@ -1,15 +1,12 @@
|
|
1 |
-
from core import
|
|
|
2 |
|
3 |
-
|
4 |
-
#
|
5 |
-
# Factool Data Post-Editor
|
6 |
-
#
|
7 |
-
# Notes:
|
8 |
-
# Factool response post-processor. Used to presents the results in human-readable format and to save the analysis in a JSON file.
|
9 |
-
#
|
10 |
-
##
|
11 |
-
@register_solver("factool_post_editor", "claim_info", "claim_info")
|
12 |
class FactoolPostEditor(StandardTaskSolver):
|
|
|
|
|
|
|
|
|
13 |
def __init__(self, args):
|
14 |
super().__init__(args)
|
15 |
|
|
|
1 |
+
from openfactcheck.core.state import FactCheckerState
|
2 |
+
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
3 |
|
4 |
+
@Solver.register("factool_post_editor", "claim_info", "claim_info")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
class FactoolPostEditor(StandardTaskSolver):
|
6 |
+
"""
|
7 |
+
A solver to post-process the results of the Factool model.
|
8 |
+
Used to presents the results in human-readable format and to save the analysis in a JSON file.
|
9 |
+
"""
|
10 |
def __init__(self, args):
|
11 |
super().__init__(args)
|
12 |
|
src/openfactcheck/solvers/factool/factool_utils/__init__.py
ADDED
File without changes
|
src/openfactcheck/solvers/factool/{ftool_utils → factool_utils}/chat_api.py
RENAMED
@@ -15,12 +15,6 @@ import openai
|
|
15 |
from openai import OpenAI, AsyncOpenAI
|
16 |
import re
|
17 |
|
18 |
-
|
19 |
-
# from factool.env_config import factool_env_config
|
20 |
-
|
21 |
-
# env
|
22 |
-
# openai.api_key = factool_env_config.openai_api_key
|
23 |
-
|
24 |
class OpenAIChat():
|
25 |
def __init__(
|
26 |
self,
|
@@ -97,11 +91,9 @@ class OpenAIChat():
|
|
97 |
'''
|
98 |
return None
|
99 |
|
100 |
-
async def dispatch_openai_requests(
|
101 |
-
|
102 |
-
|
103 |
-
) -> list[str]:
|
104 |
-
"""Dispatches requests to OpenAI API asynchronously.
|
105 |
|
106 |
Args:
|
107 |
messages_list: List of messages to be sent to OpenAI ChatCompletion API.
|
@@ -109,7 +101,7 @@ class OpenAIChat():
|
|
109 |
List of responses from OpenAI API.
|
110 |
"""
|
111 |
async def _request_with_retry(messages, retry=3):
|
112 |
-
for
|
113 |
try:
|
114 |
response = await self.client.chat.completions.create(
|
115 |
model=self.config['model_name'],
|
@@ -119,39 +111,22 @@ class OpenAIChat():
|
|
119 |
top_p=self.config['top_p']
|
120 |
)
|
121 |
return response
|
122 |
-
except openai.RateLimitError:
|
123 |
-
await asyncio.sleep(
|
124 |
-
except openai.Timeout:
|
125 |
-
await asyncio.sleep(
|
126 |
-
except
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
# except openai.error.RateLimitError:
|
132 |
-
# print('Rate limit error, waiting for 40 second...')
|
133 |
-
# await asyncio.sleep(40)
|
134 |
-
# except openai.error.APIError:
|
135 |
-
# print('API error, waiting for 1 second...')
|
136 |
-
# await asyncio.sleep(1)
|
137 |
-
# except openai.error.Timeout:
|
138 |
-
# print('Timeout error, waiting for 1 second...')
|
139 |
-
# await asyncio.sleep(1)
|
140 |
-
# except openai.error.ServiceUnavailableError:
|
141 |
-
# print('Service unavailable error, waiting for 3 second...')
|
142 |
-
# await asyncio.sleep(3)
|
143 |
-
# except openai.error.APIConnectionError:
|
144 |
-
# print('API Connection error, waiting for 3 second...')
|
145 |
-
# await asyncio.sleep(3)
|
146 |
-
|
147 |
-
return None
|
148 |
|
149 |
async_responses = [
|
150 |
_request_with_retry(messages)
|
151 |
for messages in messages_list
|
152 |
]
|
153 |
|
154 |
-
return await asyncio.gather(*async_responses)
|
155 |
|
156 |
def run(self, messages_list, expected_type):
|
157 |
retry = 1
|
@@ -159,7 +134,6 @@ class OpenAIChat():
|
|
159 |
messages_list_cur_index = [i for i in range(len(messages_list))]
|
160 |
|
161 |
while retry > 0 and len(messages_list_cur_index) > 0:
|
162 |
-
print(f'{retry} retry left...')
|
163 |
messages_list_cur = [messages_list[i] for i in messages_list_cur_index]
|
164 |
|
165 |
predictions = asyncio.run(self.dispatch_openai_requests(
|
|
|
15 |
from openai import OpenAI, AsyncOpenAI
|
16 |
import re
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
class OpenAIChat():
|
19 |
def __init__(
|
20 |
self,
|
|
|
91 |
'''
|
92 |
return None
|
93 |
|
94 |
+
async def dispatch_openai_requests(self, messages_list,) -> list[str]:
|
95 |
+
"""
|
96 |
+
Dispatches requests to OpenAI API asynchronously.
|
|
|
|
|
97 |
|
98 |
Args:
|
99 |
messages_list: List of messages to be sent to OpenAI ChatCompletion API.
|
|
|
101 |
List of responses from OpenAI API.
|
102 |
"""
|
103 |
async def _request_with_retry(messages, retry=3):
|
104 |
+
for attempt in range(retry):
|
105 |
try:
|
106 |
response = await self.client.chat.completions.create(
|
107 |
model=self.config['model_name'],
|
|
|
111 |
top_p=self.config['top_p']
|
112 |
)
|
113 |
return response
|
114 |
+
except openai.RateLimitError as e:
|
115 |
+
await asyncio.sleep((2 ** attempt) * 0.5) # exponential backoff
|
116 |
+
except (openai.Timeout, openai.APIError) as e:
|
117 |
+
await asyncio.sleep((2 ** attempt) * 0.5) # exponential backoff
|
118 |
+
except Exception as e:
|
119 |
+
# Log unexpected exception for further investigation
|
120 |
+
await asyncio.sleep((2 ** attempt) * 0.5) # fallback in case of unknown errors
|
121 |
+
|
122 |
+
raise RuntimeError("All retries failed for OpenAI API request")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
async_responses = [
|
125 |
_request_with_retry(messages)
|
126 |
for messages in messages_list
|
127 |
]
|
128 |
|
129 |
+
return await asyncio.gather(*async_responses, return_exceptions=True)
|
130 |
|
131 |
def run(self, messages_list, expected_type):
|
132 |
retry = 1
|
|
|
134 |
messages_list_cur_index = [i for i in range(len(messages_list))]
|
135 |
|
136 |
while retry > 0 and len(messages_list_cur_index) > 0:
|
|
|
137 |
messages_list_cur = [messages_list[i] for i in messages_list_cur_index]
|
138 |
|
139 |
predictions = asyncio.run(self.dispatch_openai_requests(
|
src/openfactcheck/solvers/factool/factool_utils/prompt.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""All prompts used for fact-checking subtasks prompting."""
|
2 |
+
|
3 |
+
CLAIM_EXTRACTION_PROMPT = {
|
4 |
+
"system": "Please provide the claim you would like to fact-check.",
|
5 |
+
"user": """You are given a piece of text that includes knowledge claims. A claim is a statement that asserts something as true or false, which can be verified by humans. Your task is to accurately identify and extract every claim stated in the provided text. Then, resolve any coreference (pronouns or other referring expressions) in the claim for clarity. Each claim should be concise (less than 15 words) and self-contained.
|
6 |
+
Your response MUST be a list of dictionaries. Each dictionary should contains the key "claim", which correspond to the extracted claim (with all coreferences resolved).
|
7 |
+
You MUST only respond in the format as described below. DO NOT RESPOND WITH ANYTHING ELSE. ADDING ANY OTHER EXTRA NOTES THAT VIOLATE THE RESPONSE FORMAT IS BANNED. START YOUR RESPONSE WITH '['.
|
8 |
+
[response format]:
|
9 |
+
[
|
10 |
+
{{
|
11 |
+
"claim": "Ensure that the claim is fewer than 15 words and conveys a complete idea. Resolve any coreference (pronouns or other referring expressions) in the claim for clarity",
|
12 |
+
}},
|
13 |
+
...
|
14 |
+
]
|
15 |
+
|
16 |
+
Here are two examples:
|
17 |
+
[text]: Tomas Berdych defeated Gael Monfis 6-1, 6-4 on Saturday. The sixth-seed reaches Monte Carlo Masters final for the first time . Berdych will face either Rafael Nadal or Novak Djokovic in the final.
|
18 |
+
[response]: [{{"claim": "Tomas Berdych defeated Gael Monfis 6-1, 6-4"}}, {{"claim": "Tomas Berdych defeated Gael Monfis 6-1, 6-4 on Saturday"}}, {{"claim": "Tomas Berdych reaches Monte Carlo Masters final"}}, {{"claim": "Tomas Berdych is the sixth-seed"}}, {{"claim": "Tomas Berdych reaches Monte Carlo Masters final for the first time"}}, {{"claim": "Berdych will face either Rafael Nadal or Novak Djokovic"}}, {{"claim": "Berdych will face either Rafael Nadal or Novak Djokovic in the final"}}]
|
19 |
+
|
20 |
+
[text]: Tinder only displays the last 34 photos - but users can easily see more. Firm also said it had improved its mutual friends feature.
|
21 |
+
[response]: [{{"claim": "Tinder only displays the last photos"}}, {{"claim": "Tinder only displays the last 34 photos"}}, {{"claim": "Tinder users can easily see more photos"}}, {{"claim": "Tinder said it had improved its feature"}}, {{"claim": "Tinder said it had improved its mutual friends feature"}}]
|
22 |
+
|
23 |
+
Now complete the following,ONLY RESPONSE IN A LIST FORMAT, NO OTHER WORDS!!!:
|
24 |
+
[text]: {input}
|
25 |
+
[response]:
|
26 |
+
"""
|
27 |
+
}
|
28 |
+
|
29 |
+
QUERY_GENERATION_PROMPT = {
|
30 |
+
"system": "You are a query generator that generates effective and concise search engine queries to verify a given claim. You only response in a python list format(NO OTHER WORDS!)",
|
31 |
+
"user": """You are a query generator designed to help users verify a given claim using search engines. Your primary task is to generate a Python list of two effective and skeptical search engine queries. These queries should assist users in critically evaluating the factuality of a provided claim using search engines.
|
32 |
+
You should only respond in format as described below (a Python list of queries). PLEASE STRICTLY FOLLOW THE FORMAT. DO NOT RETURN ANYTHING ELSE. START YOUR RESPONSE WITH '['.
|
33 |
+
[response format]: ['query1', 'query2']
|
34 |
+
|
35 |
+
Here are three examples:
|
36 |
+
claim: The CEO of twitter is Bill Gates.
|
37 |
+
response: ["Who is the CEO of twitter?", "CEO Twitter"]
|
38 |
+
|
39 |
+
claim: Michael Phelps is the most decorated Olympian of all time.
|
40 |
+
response: ["Who is the most decorated Olympian of all time?", "Michael Phelps"]
|
41 |
+
|
42 |
+
claim: ChatGPT is created by Google.
|
43 |
+
response: ["Who created ChatGPT?", "ChatGPT"]
|
44 |
+
|
45 |
+
Now complete the following(ONLY RESPONSE IN A LIST FORMAT, DO NOT RETURN OTHER WORDS!!! START YOUR RESPONSE WITH '[' AND END WITH ']'):
|
46 |
+
claim: {input}
|
47 |
+
response:
|
48 |
+
"""
|
49 |
+
}
|
50 |
+
|
51 |
+
VERIFICATION_PROMPT = {
|
52 |
+
"system": "You are a brilliant assistant.",
|
53 |
+
"user": """You are given a piece of text. Your task is to identify whether there are any factual errors within the text.
|
54 |
+
When you are judging the factuality of the given text, you could reference the provided evidences if needed. The provided evidences may be helpful. Some evidences may contradict to each other. You must be careful when using the evidences to judge the factuality of the given text.
|
55 |
+
The response should be a dictionary with three keys - "reasoning", "factuality", "error", and "correction", which correspond to the reasoning, whether the given text is factual or not (Boolean - True or False), the factual error present in the text, and the corrected text.
|
56 |
+
The following is the given text
|
57 |
+
[text]: {claim}
|
58 |
+
The following is the provided evidences
|
59 |
+
[evidences]: {evidence}
|
60 |
+
You should only respond in format as described below. DO NOT RETURN ANYTHING ELSE. START YOUR RESPONSE WITH '{{'.
|
61 |
+
[response format]:
|
62 |
+
{{
|
63 |
+
"reasoning": "Why is the given text factual or non-factual? Be careful when you said something is non-factual. When you said something is non-factual, you must provide multiple evidences to support your decision.",
|
64 |
+
"error": "None if the text is factual; otherwise, describe the error.",
|
65 |
+
"correction": "The corrected text if there is an error.",
|
66 |
+
"factuality": True if the given text is factual, False otherwise.
|
67 |
+
}}
|
68 |
+
"""
|
69 |
+
}
|
src/openfactcheck/solvers/factool/{ftool_utils → factool_utils}/search_api.py
RENAMED
File without changes
|
src/openfactcheck/solvers/factool/ftool_utils/prompts.yaml
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
claim_extraction:
|
2 |
-
system: |-
|
3 |
-
You are a brilliant assistant.
|
4 |
-
user: |-
|
5 |
-
You are given a piece of text that includes knowledge claims. A claim is a statement that asserts something as true or false, which can be verified by humans. Your task is to accurately identify and extract every claim stated in the provided text. Then, resolve any coreference (pronouns or other referring expressions) in the claim for clarity. Each claim should be concise (less than 15 words) and self-contained.
|
6 |
-
Your response MUST be a list of dictionaries. Each dictionary should contains the key "claim", which correspond to the extracted claim (with all coreferences resolved).
|
7 |
-
You MUST only respond in the format as described below. DO NOT RESPOND WITH ANYTHING ELSE. ADDING ANY OTHER EXTRA NOTES THAT VIOLATE THE RESPONSE FORMAT IS BANNED. START YOUR RESPONSE WITH '['.
|
8 |
-
[response format]:
|
9 |
-
[
|
10 |
-
{{
|
11 |
-
"claim": "Ensure that the claim is fewer than 15 words and conveys a complete idea. Resolve any coreference (pronouns or other referring expressions) in the claim for clarity",
|
12 |
-
}},
|
13 |
-
...
|
14 |
-
]
|
15 |
-
|
16 |
-
Here are two examples:
|
17 |
-
[text]: Tomas Berdych defeated Gael Monfis 6-1, 6-4 on Saturday. The sixth-seed reaches Monte Carlo Masters final for the first time . Berdych will face either Rafael Nadal or Novak Djokovic in the final.
|
18 |
-
[response]: [{{"claim": "Tomas Berdych defeated Gael Monfis 6-1, 6-4"}}, {{"claim": "Tomas Berdych defeated Gael Monfis 6-1, 6-4 on Saturday"}}, {{"claim": "Tomas Berdych reaches Monte Carlo Masters final"}}, {{"claim": "Tomas Berdych is the sixth-seed"}}, {{"claim": "Tomas Berdych reaches Monte Carlo Masters final for the first time"}}, {{"claim": "Berdych will face either Rafael Nadal or Novak Djokovic"}}, {{"claim": "Berdych will face either Rafael Nadal or Novak Djokovic in the final"}}]
|
19 |
-
|
20 |
-
[text]: Tinder only displays the last 34 photos - but users can easily see more. Firm also said it had improved its mutual friends feature.
|
21 |
-
[response]: [{{"claim": "Tinder only displays the last photos"}}, {{"claim": "Tinder only displays the last 34 photos"}}, {{"claim": "Tinder users can easily see more photos"}}, {{"claim": "Tinder said it had improved its feature"}}, {{"claim": "Tinder said it had improved its mutual friends feature"}}]
|
22 |
-
|
23 |
-
Now complete the following,ONLY RESPONSE IN A LIST FORMAT, NO OTHER WORDS!!!:
|
24 |
-
[text]: {input}
|
25 |
-
[response]:
|
26 |
-
|
27 |
-
query_generation:
|
28 |
-
system: |-
|
29 |
-
You are a query generator that generates effective and concise search engine queries to verify a given claim. You only response in a python list format(NO OTHER WORDS!)
|
30 |
-
user: |-
|
31 |
-
You are a query generator designed to help users verify a given claim using search engines. Your primary task is to generate a Python list of two effective and skeptical search engine queries. These queries should assist users in critically evaluating the factuality of a provided claim using search engines.
|
32 |
-
You should only respond in format as described below (a Python list of queries). PLEASE STRICTLY FOLLOW THE FORMAT. DO NOT RETURN ANYTHING ELSE. START YOUR RESPONSE WITH '['.
|
33 |
-
[response format]: ['query1', 'query2']
|
34 |
-
|
35 |
-
Here are three examples:
|
36 |
-
claim: The CEO of twitter is Bill Gates.
|
37 |
-
response: ["Who is the CEO of twitter?", "CEO Twitter"]
|
38 |
-
|
39 |
-
claim: Michael Phelps is the most decorated Olympian of all time.
|
40 |
-
response: ["Who is the most decorated Olympian of all time?", "Michael Phelps"]
|
41 |
-
|
42 |
-
claim: ChatGPT is created by Google.
|
43 |
-
response: ["Who created ChatGPT?", "ChatGPT"]
|
44 |
-
|
45 |
-
Now complete the following(ONLY RESPONSE IN A LIST FORMAT, DO NOT RETURN OTHER WORDS!!! START YOUR RESPONSE WITH '[' AND END WITH ']'):
|
46 |
-
claim: {input}
|
47 |
-
response:
|
48 |
-
|
49 |
-
verification:
|
50 |
-
system: |-
|
51 |
-
You are a brilliant assistant.
|
52 |
-
user: |-
|
53 |
-
You are given a piece of text. Your task is to identify whether there are any factual errors within the text.
|
54 |
-
When you are judging the factuality of the given text, you could reference the provided evidences if needed. The provided evidences may be helpful. Some evidences may contradict to each other. You must be careful when using the evidences to judge the factuality of the given text.
|
55 |
-
The response should be a dictionary with three keys - "reasoning", "factuality", "error", and "correction", which correspond to the reasoning, whether the given text is factual or not (Boolean - True or False), the factual error present in the text, and the corrected text.
|
56 |
-
The following is the given text
|
57 |
-
[text]: {claim}
|
58 |
-
The following is the provided evidences
|
59 |
-
[evidences]: {evidence}
|
60 |
-
You should only respond in format as described below. DO NOT RETURN ANYTHING ELSE. START YOUR RESPONSE WITH '{{'.
|
61 |
-
[response format]:
|
62 |
-
{{
|
63 |
-
"reasoning": "Why is the given text factual or non-factual? Be careful when you said something is non-factual. When you said something is non-factual, you must provide multiple evidences to support your decision.",
|
64 |
-
"error": "None if the text is factual; otherwise, describe the error.",
|
65 |
-
"correction": "The corrected text if there is an error.",
|
66 |
-
"factuality": True if the given text is factual, False otherwise.
|
67 |
-
}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/openfactcheck/solvers/webservice/__init__.py
ADDED
File without changes
|
src/openfactcheck/solvers/webservice/factool_utils/chat_api.py
CHANGED
@@ -151,7 +151,7 @@ class OpenAIChat():
|
|
151 |
for messages in messages_list
|
152 |
]
|
153 |
|
154 |
-
return await asyncio.gather(*async_responses)
|
155 |
|
156 |
def run(self, messages_list, expected_type):
|
157 |
retry = 1
|
@@ -159,7 +159,6 @@ class OpenAIChat():
|
|
159 |
messages_list_cur_index = [i for i in range(len(messages_list))]
|
160 |
|
161 |
while retry > 0 and len(messages_list_cur_index) > 0:
|
162 |
-
print(f'{retry} retry left...')
|
163 |
messages_list_cur = [messages_list[i] for i in messages_list_cur_index]
|
164 |
|
165 |
predictions = asyncio.run(self.dispatch_openai_requests(
|
|
|
151 |
for messages in messages_list
|
152 |
]
|
153 |
|
154 |
+
return await asyncio.gather(*async_responses, return_exceptions=True)
|
155 |
|
156 |
def run(self, messages_list, expected_type):
|
157 |
retry = 1
|
|
|
159 |
messages_list_cur_index = [i for i in range(len(messages_list))]
|
160 |
|
161 |
while retry > 0 and len(messages_list_cur_index) > 0:
|
|
|
162 |
messages_list_cur = [messages_list[i] for i in messages_list_cur_index]
|
163 |
|
164 |
predictions = asyncio.run(self.dispatch_openai_requests(
|
src/openfactcheck/solvers/webservice/ftool_cp.py
CHANGED
@@ -1,6 +1,3 @@
|
|
1 |
-
import os
|
2 |
-
import yaml
|
3 |
-
|
4 |
from openfactcheck.core.state import FactCheckerState
|
5 |
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
6 |
|
|
|
|
|
|
|
|
|
1 |
from openfactcheck.core.state import FactCheckerState
|
2 |
from openfactcheck.core.solver import StandardTaskSolver, Solver
|
3 |
|