|
import streamlit as st |
|
from pathlib import Path |
|
from uuid import uuid4 |
|
import csv |
|
from datetime import datetime, timezone |
|
|
|
from huggingface_hub import CommitScheduler |
|
|
|
|
|
# Local folder whose contents the CommitScheduler below uploads; it is
# created eagerly when this module is executed.
CSV_DATASET_DIR = Path("flagged_rows")
CSV_DATASET_DIR.mkdir(exist_ok=True, parents=True)

# A fresh uuid4-named CSV for each execution of this module, presumably so
# separate app instances never append to the same file — confirm intent.
CSV_DATASET_PATH = CSV_DATASET_DIR.joinpath(f"train-{uuid4()}.csv")

# Flipped to True by write_header() after the column-name row has been
# written to CSV_DATASET_PATH.
wrote_header = False
|
|
|
|
|
def write_header(writer):
    """Write the CSV column-name row through *writer* and mark it as written.

    Args:
        writer: a ``csv.writer``-like object exposing ``writerow``.

    Side effects:
        Sets the module-level ``wrote_header`` flag to True, but only after
        the row has been handed to ``writer`` successfully.
    """
    global wrote_header

    # Column order must match the rows produced in report_dialog().
    header = (
        "date",
        "grascii",
        "longhand",
        "incorrect_grascii",
        "incorrect_longhand",
        "incorrect_shorthand",
        "improperly_cropped",
        "extraneous_marks",
    )
    writer.writerow(header)
    wrote_header = True
|
|
|
|
|
# Background uploader that pushes the contents of CSV_DATASET_DIR into the
# "data" folder of the dataset repo named by the FEEDBACK_REPO secret.
# NOTE(review): if this module is the Streamlit entry script it is re-executed
# on every rerun, creating a new scheduler (and CSV path) each time — confirm
# it is only imported once.
scheduler = CommitScheduler(
    repo_id=st.secrets.FEEDBACK_REPO,
    token=st.secrets.HF_TOKEN,
    repo_type="dataset",
    folder_path=CSV_DATASET_DIR,
    path_in_repo="data",
    every=15,  # interval between scheduled commits, in minutes
)
|
|
|
|
|
@st.dialog("Flag Results for Review", width="large")
def report_dialog(data):
    """Show an editable table in which the user flags result rows for review.

    Five checkbox columns ("3" through "7") are added next to the first three
    columns of *data* (Grascii, Longhand, Shorthand image). When "Submit" is
    pressed, every row with at least one box ticked is appended to
    CSV_DATASET_PATH while holding ``scheduler.lock``, so the background
    upload never sees a half-written file. Then
    ``st.session_state["report_submitted"]`` is set and the app reruns.

    Args:
        data: DataFrame whose first three columns hold the grascii string,
            the longhand word and the shorthand image, in that order. It is
            copied, never modified.
    """
    st.write("Please select one or more reasons for flagging each row:")

    # Labels of the checkbox columns this dialog adds; also the order of the
    # 0/1 flag fields in each CSV row (see write_header).
    flag_columns = ["3", "4", "5", "6", "7"]

    # Copy so adding checkbox columns does not mutate the caller's DataFrame.
    report_df = data.copy()
    for column in flag_columns:
        report_df[column] = False

    final_report = st.data_editor(
        report_df,
        hide_index=True,
        column_config={
            "0": "Grascii",
            "1": "Longhand",
            "2": st.column_config.ImageColumn("Shorthand", width="medium"),
            "3": st.column_config.CheckboxColumn("Grascii is incorrect"),
            "4": st.column_config.CheckboxColumn("Longhand is incorrect"),
            "5": st.column_config.CheckboxColumn("Shorthand image is incorrect"),
            "6": st.column_config.CheckboxColumn(
                "Shorthand image is improperly cropped"
            ),
            "7": st.column_config.CheckboxColumn(
                "Shorthand image contains extraneous marks"
            ),
        },
        disabled=["0", "1", "2"],
        use_container_width=True,
    )

    if st.button("Submit"):
        # Explicit row iteration instead of DataFrame.apply: on an empty
        # frame, apply probes the function with an all-NaN row (NaN is
        # truthy), which would log a bogus flagged row.
        flagged = final_report[final_report[flag_columns].any(axis=1)]
        today = datetime.now(timezone.utc).date()
        csv_rows = [
            [
                today,
                row.iloc[0],
                row.iloc[1],
                *(1 if row[column] else 0 for column in flag_columns),
            ]
            for _, row in flagged.iterrows()
        ]

        # Only touch the file when there is something to record, so no
        # header-only CSVs get created and uploaded.
        if csv_rows:
            with scheduler.lock:
                with open(CSV_DATASET_PATH, "a", newline="") as f:
                    writer = csv.writer(f, dialect="unix")
                    if not wrote_header:
                        write_header(writer)
                    writer.writerows(csv_rows)

        st.session_state["report_submitted"] = True
        st.rerun()
|
|