|
import base64 |
|
from io import BytesIO |
|
import numpy as np |
|
import streamlit as st |
|
from PIL import Image |
|
import pandas as pd |
|
from datasets import load_dataset |
|
from grascii import GrasciiSearcher, InvalidGrascii, ReverseSearcher |
|
from report import report_dialog |
|
from vision import run_vision |
|
|
|
|
|
|
|
# Maximum number of characters allowed in a Grascii query. It is passed as
# ``max_chars`` to the Grascii text box and checked again before a search runs.
MAX_GRASCII_LENGTH = 16
|
|
|
|
|
@st.cache_data(show_spinner="Loading shorthand images")
def load_images():
    """Build a mapping from longhand word to a PNG data URI of its shorthand.

    Loads the ``train`` split of the ``grascii/gregg-preanniversary-words``
    dataset (authenticated with ``st.secrets.HF_TOKEN``), re-encodes each
    row's ``image`` as PNG and stores it, base64-encoded, under the row's
    ``longhand`` value.

    Returns:
        dict: ``longhand`` -> ``"data:image/png;base64,..."`` string. When a
        longhand value occurs more than once, the last row wins.
    """

    def to_data_uri(picture):
        # Serialise in memory; nothing is written to disk.
        png_bytes = BytesIO()
        picture.save(png_bytes, format="PNG")
        encoded = base64.b64encode(png_bytes.getvalue()).decode("utf-8")
        return "data:image/png;base64," + encoded

    dataset = load_dataset(
        "grascii/gregg-preanniversary-words", split="train", token=st.secrets.HF_TOKEN
    )
    return {record["longhand"]: to_data_uri(record["image"]) for record in dataset}
|
|
|
|
|
# Module-level lookup of longhand word -> shorthand image data URI.
# write_results reads it to fill the image column of the results table.
image_map = load_images()
|
|
|
|
|
def on_submit():
    """Form-submit callback for the Grascii search form.

    When the Grascii text box is present in session state, its value becomes
    the active ``grascii`` query and any stored ``alternatives`` are cleared.
    When the text box is absent, nothing is changed.
    """
    state = st.session_state
    if "grascii_text_box" not in state:
        return
    state["grascii"] = state["grascii_text_box"]
    state["alternatives"] = {}
|
|
|
|
|
def write_grascii_search():
    """Render the Grascii search form and the results of the current query.

    The user chooses between searching by text (a Grascii string typed into a
    text box) or by image (an uploaded picture that is passed to
    ``run_vision``). In image mode the predictions become the selectable
    "alternatives" and the first prediction becomes the active query. The
    active query is read from ``st.session_state["grascii"]`` and searched
    with ``GrasciiSearcher.sorted_search`` using the options chosen in the
    "Options" expander.

    Session state used:
        grascii: the active query string (read and, in image mode, written).
        alternatives: dict whose keys are the predicted Grascii strings
            (read and, in image mode, written).
        grascii_text_box / alternative: widget keys of the text box and of
            the alternatives pills.

    Returns:
        None. Output is rendered through Streamlit.
    """
    # Seed the session keys this function reads so that a first run cannot
    # fail with a KeyError if nothing else has initialised them.
    if "grascii" not in st.session_state:
        st.session_state["grascii"] = ""
    if "alternatives" not in st.session_state:
        st.session_state["alternatives"] = {}

    searcher = GrasciiSearcher()
    grascii_results = []

    search_by = st.radio("Search by", ["text", "image (beta)"], horizontal=True)

    with st.form("Grascii Search"):
        placeholder = st.empty()
        if search_by == "text":
            placeholder.text_input(
                "Grascii",
                value=st.session_state["grascii"],
                key="grascii_text_box",
                max_chars=MAX_GRASCII_LENGTH,
                help="[Grascii Language Reference](https://grascii.readthedocs.io/en/stable/language.html)",
            )
        else:
            with placeholder.container():
                image_data = st.file_uploader(
                    "Image",
                    type=["png", "jpg"],
                    help="""
Upload an image of a shorthand form.

At this time, minimal preprocessing is performed on images
before running them through the model. For best results,
upload an image:

- of a closely cropped, single shorthand form
- with the shorthand written in black on a white background
- that does not contain marks beside the shorthand form
""",
                )

                if image_data:
                    # Flatten any transparency onto a white background, then
                    # convert to grayscale and wrap in a batch of one image.
                    image = Image.open(image_data).convert("RGBA")
                    background = Image.new("RGBA", image.size, (255, 255, 255))
                    alpha_composite = Image.alpha_composite(background, image)

                    arr = np.array([alpha_composite.convert("L")])
                    predictions = run_vision(arr)
                    if len(predictions) == 0:
                        # Without this guard, predictions[0] below would raise
                        # IndexError when the model returns nothing.
                        st.warning("No shorthand form was recognized in the image.")
                    else:
                        alternatives = {"".join(p): True for p in predictions}
                        # Only reset the query when the predictions changed, so
                        # reruns for the same image keep the current query.
                        if st.session_state["alternatives"] != alternatives:
                            st.session_state["alternatives"] = alternatives
                            st.session_state["grascii"] = "".join(predictions[0])

        with st.expander("Options"):
            interpretation = st.radio(
                "Interpretation",
                ["best", "all"],
                horizontal=True,
                help="""
How to interpret ambiguous Grascii strings.

- best: Only search using the best interpretation.
- all: Search using all possible interpretations.
""",
            )
            uncertainty = st.slider(
                "Uncertainty",
                min_value=0,
                max_value=2,
                value=1,
                help="""
The uncertainty of the strokes in the Grascii string.

A value of at least 1 is recommended for image searches.
""",
            )
            fix_first = st.checkbox(
                "Fix First", help="Apply an uncertainty of 0 to the first token."
            )
            search_mode = st.selectbox(
                "Search Mode",
                ["match", "start", "contain"],
                help="""
The type of search to perform.

- match: Search for entries that closely match the Grascii string.
- start: Search for entries that start with the Grascii string.
- contain: Search for entries that contain the Grascii string.
""",
            )
            annotation_mode = st.selectbox(
                "Annotation Mode",
                ["strict", "retain", "discard"],
                index=2,
                help="""
How to handle Grascii annotations.

- discard: Annotations are discarded.
Search results may contain annotations in any location.
- retain: Annotations in the input must appear in search results.
Other annotations may appear in the results.
- strict: Annotations in the input must appear in search results.
Other annotations may not appear in the results.
""",
            )
            aspirate_mode = st.selectbox(
                "Aspirate Mode",
                ["strict", "retain", "discard"],
                index=2,
                help="""
How to handle Grascii aspirates (').

- discard: Aspirates are discarded.
Search results may contain aspirates in any location.
- retain: Aspirates in the input must appear in search results.
Other aspirates may appear in the results.
- strict: Aspirates in the input must appear in search results.
Other aspirates may not appear in the results.
""",
            )
            disjoiner_mode = st.selectbox(
                "Disjoiner Mode",
                ["strict", "retain", "discard"],
                index=0,
                help="""
How to handle Grascii disjoiners (^).

- discard: Disjoiners are discarded.
Search results may contain disjoiners in any location.
- retain: Disjoiners in the input must appear in search results.
Other disjoiners may appear in the results.
- strict: Disjoiners in the input must appear in search results.
Other disjoiners may not appear in the results.
""",
            )

        st.form_submit_button("Search", on_click=on_submit)

    grascii = st.session_state["grascii"]

    # The text box enforces max_chars, but the query can also come from an
    # image prediction, so the length is checked again here.
    if len(grascii) > MAX_GRASCII_LENGTH:
        st.error(f"Grascii too long. Max: {MAX_GRASCII_LENGTH} characters")
        return

    try:
        grascii_results = searcher.sorted_search(
            grascii=grascii,
            interpretation=interpretation,
            uncertainty=uncertainty,
            fix_first=fix_first,
            search_mode=search_mode,
            annotation_mode=annotation_mode,
            aspirate_mode=aspirate_mode,
            disjoiner_mode=disjoiner_mode,
        )
    except InvalidGrascii as e:
        # Stay silent for an empty query; only report a non-empty invalid one.
        if grascii:
            st.error(f"Invalid Grascii\n```\n{e.context}\n```")
    else:
        # Offer the other predictions only when there is more than one.
        if len(st.session_state["alternatives"]) > 1:
            st.pills(
                "Alternatives",
                st.session_state["alternatives"],
                key="alternative",
                default=grascii,
                on_change=on_alternative_selection,
            )
        write_results(grascii_results, grascii.upper(), "grascii")
|
|
|
|
|
def on_alternative_selection():
    """Change callback for the "Alternatives" pills.

    A selected alternative becomes the active ``grascii`` query. If the
    selection was cleared (the widget value is ``None``), the widget value is
    set back to the current ``grascii`` query instead.
    """
    state = st.session_state
    chosen = state["alternative"]
    if chosen is not None:
        state["grascii"] = chosen
        return
    # Selection was cleared: put the pill for the active query back.
    state["alternative"] = state["grascii"]
|
|
|
|
|
@st.fragment
def write_results(results, term, key_prefix):
    """Render a result count, a selectable results table and a flag button.

    Args:
        results: iterable of search results; each item must expose
            ``entry.grascii`` and ``entry.translation``.
        term: the search term echoed in the result-count line.
        key_prefix: prefix used to build the widget keys of the table and
            the button.
    """
    # One row per result: Grascii, longhand, and the shorthand image looked up
    # by longhand (None when image_map has no entry for it).
    table = pd.DataFrame(
        [
            [
                hit.entry.grascii,
                hit.entry.translation,
                image_map.get(hit.entry.translation),
            ]
            for hit in results
        ]
    )

    count = len(table)
    noun = "Result" if count == 1 else "Results"
    st.write(f'{count} {noun} for "{term}"')

    column_settings = {
        "0": "Grascii",
        "1": "Longhand",
        "2": st.column_config.ImageColumn("Shorthand", width="medium"),
    }
    event = st.dataframe(
        table,
        use_container_width=True,
        column_config=column_settings,
        selection_mode="multi-row",
        on_select="rerun",
        key=key_prefix + "_data_frame",
        hide_index=True,
    )
    selected_rows = event.selection.rows
    nothing_selected = len(selected_rows) == 0

    flag_clicked = st.button(
        "Flag Selected Rows",
        key=key_prefix + "_report_button",
        disabled=nothing_selected,
    )
    if flag_clicked:
        # Hand the selected rows to the report dialog.
        report_dialog(table.iloc[selected_rows])
|
|
|
|
|
def write_reverse_search():
    """Render the reverse search form and the results for the entered word(s).

    The text entered in the "Word(s)" box is passed to
    ``ReverseSearcher.sorted_search`` as ``reverse``, and the results are
    rendered with ``write_results``. Nothing is searched or rendered while
    the box is empty.
    """
    searcher = ReverseSearcher()

    with st.form("Reverse Search"):
        word = st.text_input("Word(s)")

        st.form_submit_button("Search")

    if not word:
        return

    matches = searcher.sorted_search(
        reverse=word,
    )
    write_results(matches, word, "reverse")
|
|