|
import os |
|
import uuid |
|
import zipfile |
|
import pandas as pd |
|
import seaborn as sns |
|
import streamlit as st |
|
import matplotlib.pyplot as plt |
|
from importlib import resources as pkg_resources |
|
|
|
from openfactcheck.app.utils import metric_card |
|
from openfactcheck.factchecker.evaluate import FactCheckerEvaluator |
|
from openfactcheck.templates import factchecker as templates_dir |
|
|
|
|
|
# Filesystem paths (as strings) of the benchmark template files that ship
# inside the ``openfactcheck.templates.factchecker`` package.
claims_templates_path = str(
    pkg_resources.files(templates_dir).joinpath("claims.jsonl")
)
documents_templates_path = str(
    pkg_resources.files(templates_dir).joinpath("documents.jsonl")
)
|
|
|
def _build_benchmark_archive():
    """Return an in-memory zip archive holding the claims benchmark template."""
    from io import BytesIO

    archive = BytesIO()
    with zipfile.ZipFile(archive, "w") as zf:
        # Store the file under its bare name so the archive has no directory nesting.
        zf.write(claims_templates_path, arcname=os.path.basename(claims_templates_path))
    # Rewind so that readers of the buffer start at the beginning of the archive.
    archive.seek(0)
    return archive


def _accuracy_card_colors(accuracy):
    """Return the ``(background, border)`` colours for the accuracy metric card."""
    if 0.75 < accuracy <= 1:
        return "#D4EDDA", "#28A745"  # green
    if 0.25 < accuracy <= 0.75:
        return "#FFF3CD", "#FFC107"  # amber
    return "#F8D7DA", "#DC3545"  # red


def evaluate_factchecker():
    """Render the Streamlit page used to evaluate a FactChecker.

    The page:

    1. offers the claims benchmark template as a zip download (an error is
       shown instead when the template files are missing);
    2. accepts a CSV upload of FactChecker responses and stops early, with a
       message, when no file or an unreadable/invalid file is supplied;
    3. collects the submitter's details, mirroring each widget value into
       ``st.session_state`` through ``on_change`` callbacks;
    4. when "Evaluate FactChecker" is pressed, runs ``FactCheckerEvaluator``
       on the uploaded data and displays the confusion matrix, the accuracy,
       total time / cost / sample count and the classification report.

    Returns:
        None. All output is written to the Streamlit page.
    """
    st.write("This is where you can evaluate the factuality of a FactChecker.")

    # --- Benchmark download -------------------------------------------------
    st.write("Download the benchmark evaluate the factuality of a FactChecker.")

    if os.path.exists(claims_templates_path) and os.path.exists(documents_templates_path):
        st.download_button(
            label="Download",
            data=_build_benchmark_archive(),
            file_name="openfactcheck_factchecker_benchmark.zip",
            mime="application/zip",
        )
    else:
        st.error("File not found.")

    # --- Response upload ----------------------------------------------------
    # The uploader only accepts CSV, so the prompt must say CSV (it used to say JSON).
    st.write("Upload the FactChecker responses as a CSV file below to evaluate the factuality.")

    uploaded_file = st.file_uploader("Upload", type=["csv"], label_visibility="collapsed")

    if uploaded_file is None:
        st.info("Please upload a CSV file.")
        return

    # The MIME type is reported by the browser and is not always "text/csv"
    # for genuine CSV files (e.g. "application/vnd.ms-excel" on some setups),
    # so a ".csv" file name is accepted as well.
    has_csv_name = uploaded_file.name.lower().endswith(".csv")
    if uploaded_file.type != "text/csv" and not has_csv_name:
        st.error("Invalid file format. Please upload a CSV file.")
        return

    try:
        uploaded_data = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError):
        # Report unreadable uploads on the page instead of crashing with a traceback.
        st.error("Invalid file format. Please upload a CSV file.")
        return

    # --- Submitter details --------------------------------------------------
    # Each callback copies the widget's value (stored under its "input_*" key)
    # into a separate session-state entry.
    def update_first_name():
        st.session_state.first_name = st.session_state.input_first_name

    def update_last_name():
        st.session_state.last_name = st.session_state.input_last_name

    def update_email():
        st.session_state.email = st.session_state.input_email

    def update_organization():
        st.session_state.organization = st.session_state.input_organization

    def update_factchecker():
        st.session_state.factchecker = st.session_state.input_factchecker

    def update_include_in_leaderboard():
        st.session_state.include_in_leaderboard = st.session_state.input_include_in_leaderboard

    st.write("Please provide the following information to be included in the leaderboard.")

    # NOTE(review): a new id is generated every time this function runs, i.e.
    # on every Streamlit rerun - confirm whether a per-session id was intended.
    st.session_state.id = uuid.uuid4().hex
    st.text_input("First Name", key="input_first_name", on_change=update_first_name)
    st.text_input("Last Name", key="input_last_name", on_change=update_last_name)
    st.text_input("Email", key="input_email", on_change=update_email)
    st.text_input("FactChecker Name", key="input_factchecker", on_change=update_factchecker)
    st.text_input("Organization (Optional)", key="input_organization", on_change=update_organization)

    st.checkbox(
        "Please check this box if you want your FactChecker to be included in the leaderboard.",
        key="input_include_in_leaderboard",
        on_change=update_include_in_leaderboard,
    )

    if not st.button("Evaluate FactChecker"):
        return

    st.success("User information saved successfully.")

    # --- Evaluation ---------------------------------------------------------
    with st.status("Evaluating factuality of the FactChecker...", expanded=True) as status:
        fce = FactCheckerEvaluator(input=uploaded_data, eval_type="claims")
        # Calling the evaluator runs the evaluation; the values displayed below
        # are read from its attributes afterwards.
        fce()
        status.update(label="FactChecker evaluated...", state="complete", expanded=False)

    st.write("### Evaluation report:")

    # Example of the structure read from ``fce.results`` below:
    #
    # {
    #     "True_as_positive": {
    #         "accuracy": 0.486,
    #         "precision": 0.71,
    #         "recall": 0.478,
    #         "F1": 0.571
    #     },
    #     "False_as_positive": {
    #         "accuracy": 0.486,
    #         "precision": 0.277,
    #         "recall": 0.506,
    #         "F1": 0.358
    #     },
    #     "total_time": 14430.0,
    #     "total_cost": 144.3,
    #     "num_samples": 1443
    # }

    col1, col2 = st.columns(2, gap="large")

    with col1:
        # Confusion matrix heatmap, drawn on explicit axes rather than the
        # global pyplot state.
        classes = ["True", "False"]
        fig, ax = plt.subplots()
        sns.heatmap(
            fce.confusion_matrix,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=classes,
            yticklabels=classes,
            ax=ax,
        )
        ax.set_ylabel("Actual Class")
        ax.set_xlabel("Predicted Class")
        st.pyplot(fig)
        # Release the figure; otherwise one more figure stays open after every run.
        plt.close(fig)

    with col2:
        # Accuracy card, colour-coded by how high the accuracy is.
        accuracy = fce.results["True_as_positive"]["accuracy"]
        background_color, border_left_color = _accuracy_card_colors(accuracy)
        metric_card(
            label="Accuracy",
            value=f"{accuracy:.2%}",
            background_color=background_color,
            border_left_color=border_left_color,
        )

        sub_col1, sub_col2, sub_col3 = st.columns(3)
        with sub_col1:
            metric_card(label="Total Time", value=fce.results["total_time"])
        with sub_col2:
            metric_card(label="Total Cost", value=fce.results["total_cost"])
        with sub_col3:
            metric_card(label="Number of Samples", value=fce.results["num_samples"])

        st.text("Report:\n" + fce.classification_report)
|
|
|
|
|
|