rd_table_bench / 1_🔍_Explorer.py
raunakdoesdev's picture
simplify
727cda5
from huggingface_hub import hf_hub_download
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import os
import zipfile
import shutil
# Lay the page out across the full browser width.
st.set_page_config(layout="wide")

# Fetch the benchmark archive from the Hugging Face Hub. The value returned by
# hf_hub_download is stored in `results` and is what unzip_dataset() opens as
# a zip file.
with st.spinner("Downloading dataset"):
    results = hf_hub_download(
        repo_id="reducto/rd-tablebench",
        filename="rd-tablebench.zip",
        repo_type="dataset",
    )
def unzip_dataset():
    """Extract the downloaded archive once and return the dataset directory.

    Reads the module-level ``results`` (the archive path at call time) and
    extracts it into ``unzipped_dataset``. Extraction is skipped whenever that
    directory already exists, whatever its contents.

    Returns:
        The relative path ``"unzipped_dataset/rd-tablebench"``.
    """
    target_dir = "unzipped_dataset"

    # Already extracted by an earlier run: nothing to do.
    if os.path.exists(target_dir):
        return "unzipped_dataset/rd-tablebench"

    os.makedirs(target_dir)
    with st.spinner("Unzipping dataset"), zipfile.ZipFile(results, "r") as archive:
        archive.extractall(target_dir)
    return "unzipped_dataset/rd-tablebench"
# Manual escape hatch: delete the extracted copy so that the rerun triggered
# below extracts the archive again from scratch.
if st.button("Redo Unzip"):
    if os.path.exists("unzipped_dataset"):
        shutil.rmtree("unzipped_dataset")
    st.rerun()

dataset = unzip_dataset()

# NOTE: `results` is rebound here, from the archive path to the scores CSV path.
results = f"{dataset}/providers/scores.csv"

# Explicit check instead of `assert`: asserts are stripped under `python -O`
# and give the user no hint about what went wrong or how to recover.
if not os.path.exists(results):
    st.error(
        f'Scores file not found at {results}. Use "Redo Unzip" to re-extract '
        "the dataset."
    )
    st.stop()
# Style rules for `table`, `td` and `th` elements: collapsed 1px borders,
# preserved whitespace, left-aligned text and normal-weight header cells.
table_css = """
<style>
table {
font-family: arial, sans-serif;
border-collapse: collapse;
white-space: pre;
}
td, th {
border: 1px solid #dddddd;
text-align: left;
padding: 8px;
font-weight: normal;
}
</style>
"""
st.html(table_css)
# Load the scores CSV; each row is addressed by its positional index below.
df = pd.read_csv(results)

# With no rows the index widget would get max_value=-1 (< min_value) and the
# row lookup further down would fail, so stop early with a readable message.
if df.empty:
    st.error("The scores file contains no rows to display.")
    st.stop()

# The selected row index is kept in session state so it persists across reruns.
if "current_index" not in st.session_state:
    st.session_state.current_index = 0

last_index = len(df) - 1
col1, col2, col3 = st.columns([2, 5, 2])

with col1:
    # Presumably a spacer so the button lines up with the number input.
    st.html("<br/>")
    if st.button("⬅️ Previous", use_container_width=True):
        if st.session_state.current_index > 0:
            st.session_state.current_index -= 1
            st.rerun()

# Search box and Go button in col2
with col2:
    index_input = st.number_input(
        "Index",
        label_visibility="hidden",
        min_value=0,
        max_value=last_index,
        value=st.session_state.current_index,
        step=1,
    )
    # The typed index only takes effect once "Go" is pressed.
    if st.button("Go", use_container_width=True):
        st.session_state.current_index = int(index_input)
        st.rerun()

# Next button in col3
with col3:
    st.html("<br/>")
    # Label repaired: the arrow emoji was corrupted into a different glyph.
    if st.button("Next ➡️", use_container_width=True):
        if st.session_state.current_index < last_index:
            st.session_state.current_index += 1
            st.rerun()
col1, col2 = st.columns([1, 2])

# Providers to compare; each one is looked up as a "<provider>_score" column.
providers = [
    "reducto",
    "azure",
    "textract",
    "gcloud",
    "unstructured",
    "gpt4o",
    "chunkr",
]

with col1:
    row = df.iloc[st.session_state.current_index]

    # Extract scores. Missing CSV cells arrive from pandas as NaN, not None,
    # so use pd.isna to map them to 0 instead of plotting/labelling "nan".
    raw_scores = [row[f"{p}_score"] for p in providers]
    scores = [0 if pd.isna(value) else value for value in raw_scores]

    fig, ax = plt.subplots(figsize=(6, 10))
    # Reverse both lists so the first provider is drawn at the top of the chart.
    bars = ax.barh(providers[::-1], scores[::-1])

    # Customize plot
    ax.set_title("Provider Scores Comparison")
    ax.set_ylabel("Providers")
    ax.set_xlabel("Scores")
    ax.set_xlim(0, 1.1)

    # Print each bar's value just past its right-hand end.
    for bar in bars:
        width = bar.get_width()
        ax.text(
            width,
            bar.get_y() + bar.get_height() / 2.0,
            f"{width:.3f}",
            ha="left",
            va="center",
        )

    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this a new
    # figure would accumulate on each rerun of the script.
    plt.close(fig)
with col2:
    # All per-document artifacts share the PDF's relative path, differing only
    # in extension; compute the HTML variant once for reuse below.
    pdf_path = row["pdf_path"]
    html_name = pdf_path.replace(".pdf", ".html")

    image_path = f"{dataset}/_images/{pdf_path.replace('.pdf', '.jpg')}"
    # NOTE(review): newer Streamlit releases deprecate `use_column_width` on
    # st.image — confirm against the pinned version before changing it.
    st.image(image_path, use_column_width=True)
    st.write(row)

    st.subheader("Groundtruth")
    groundtruth_html = f"{dataset}/groundtruth/{html_name}"
    # Check the file first, as the provider loop does: handing st.html a path
    # that does not exist would render the path text instead of a table.
    if os.path.exists(groundtruth_html):
        st.html(groundtruth_html)
    else:
        st.error("Groundtruth table is missing for this image")

    st.subheader("Provider Outputs")
    for p in providers:
        with st.expander(p):
            provider_html = f"{dataset}/providers/{p}/{html_name}"
            if os.path.exists(provider_html):
                st.html(provider_html)
            else:
                st.error(f"{p} failed to produce a table output for this image")