"""Streamlit UI for searching a Gregg shorthand dictionary with Grascii.

Provides a forward search (Grascii string or shorthand image -> words) and a
reverse search (longhand words -> Grascii), both rendered as result tables
with shorthand images attached.
"""

import base64
from io import BytesIO

import numpy as np
import pandas as pd
import streamlit as st
from datasets import load_dataset
from grascii import GrasciiSearcher, InvalidGrascii, ReverseSearcher
from PIL import Image

from report import report_dialog
from vision import run_vision

# from save_image import save_image

# Upper bound on the length of a Grascii search string; enforced both by the
# text input widget and again before searching.
MAX_GRASCII_LENGTH = 16


@st.cache_data(show_spinner="Loading shorthand images")
def load_images() -> dict:
    """Build a mapping from longhand word to an inline PNG data URI.

    Loads the ``train`` split of ``grascii/gregg-preanniversary-words`` using
    the ``HF_TOKEN`` Streamlit secret, re-encodes each row's image as PNG and
    base64-encodes it. The result is cached by ``st.cache_data``.

    Returns:
        dict mapping each row's ``longhand`` value to a
        ``data:image/png;base64,...`` string. If several rows share a
        longhand value, the last one wins.
    """
    ds = load_dataset(
        "grascii/gregg-preanniversary-words", split="train", token=st.secrets.HF_TOKEN
    )
    images = {}
    for row in ds:
        buffered = BytesIO()
        row["image"].save(buffered, format="PNG")
        b64 = base64.b64encode(buffered.getvalue())
        images[row["longhand"]] = "data:image/png;base64," + b64.decode("utf-8")
    return images


image_map = load_images()


def on_submit() -> None:
    """Form-submit callback for the Grascii search form.

    When the text box value is present in session state, it becomes the
    active Grascii string and any image-derived alternatives are cleared.
    """
    # NOTE(review): presumably the key is absent when the image uploader is
    # shown instead of the text box -- confirm against Streamlit's widget
    # state cleanup behaviour.
    if "grascii_text_box" in st.session_state:
        st.session_state["grascii"] = st.session_state["grascii_text_box"]
        st.session_state["alternatives"] = {}


def _write_image_input(placeholder) -> None:
    """Render the image uploader inside ``placeholder`` and process an upload.

    When an image is uploaded it is flattened onto a white background,
    converted to grayscale and passed to ``run_vision``. The predictions are
    stored in ``st.session_state["alternatives"]`` and, when they differ from
    the stored ones, the first prediction becomes
    ``st.session_state["grascii"]``.
    """
    with placeholder.container():
        image_data = st.file_uploader(
            "Image",
            type=["png", "jpg"],
            help="""
            Upload an image of a shorthand form. At this time, minimal preprocessing is performed on images before running them through the model.

            For best results, upload an image:
            - of a closely cropped, single shorthand form
            - with the shorthand written in black on a white background
            - that does not contain marks beside the shorthand form
            """,
        )
        # save = st.checkbox(
        #     "Save images I upload for potential inclusion in open-source datasets used to train and improve models",
        #     key="save_image",
        # )
        if not image_data:
            return

        # Composite onto opaque white so transparent pixels do not turn into
        # arbitrary values when the image is reduced to grayscale.
        image = Image.open(image_data).convert("RGBA")
        background = Image.new("RGBA", image.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, image)
        # The model input is a batch containing this single grayscale image.
        arr = np.array([alpha_composite.convert("L")])
        predictions = run_vision(arr)

        # A dict is used as an ordered set: it de-duplicates the joined
        # predictions while preserving the order returned by the model.
        alternatives = {"".join(p): True for p in predictions}
        if st.session_state["alternatives"] != alternatives:
            st.session_state["alternatives"] = alternatives
            # The first key is the first prediction joined into a string.
            # Guarding on emptiness avoids an IndexError when the model
            # returns no predictions at all.
            if alternatives:
                st.session_state["grascii"] = next(iter(alternatives))
            # if save:
            #     save_image(image_data.getvalue(), "-".join(predictions[0]))


def _write_search_options() -> dict:
    """Render the "Options" expander and return the chosen search options.

    Returns:
        dict of keyword arguments for ``GrasciiSearcher.sorted_search``
        (everything except ``grascii`` itself).
    """
    with st.expander("Options"):
        interpretation = st.radio(
            "Interpretation",
            ["best", "all"],
            horizontal=True,
            help="""
            How to interpret ambiguous Grascii strings.

            - best: Only search using the best interpretation.
            - all: Search using all possible interpretations.
            """,
        )
        uncertainty = st.slider(
            "Uncertainty",
            min_value=0,
            max_value=2,
            value=1,
            help="""
            The uncertainty of the strokes in the Grascii string.
            A value of at least 1 is recommended for image searches.
            """,
        )
        fix_first = st.checkbox(
            "Fix First", help="Apply an uncertainty of 0 to the first token."
        )
        search_mode = st.selectbox(
            "Search Mode",
            ["match", "start", "contain"],
            help="""
            The type of search to perform.

            - match: Search for entries that closely match the Grascii string.
            - start: Search for entries that start with the Grascii string.
            - contain: Search for entries that contain the Grascii string.
            """,
        )
        annotation_mode = st.selectbox(
            "Annotation Mode",
            ["strict", "retain", "discard"],
            index=2,
            help="""
            How to handle Grascii annotations.

            - discard: Annotations are discarded.
              Search results may contain annotations in any location.
            - retain: Annotations in the input must appear in search results.
              Other annotations may appear in the results.
            - strict: Annotations in the input must appear in search results.
              Other annotations may not appear in the results.
            """,
        )
        aspirate_mode = st.selectbox(
            "Aspirate Mode",
            ["strict", "retain", "discard"],
            index=2,
            help="""
            How to handle Grascii aspirates (').

            - discard: Aspirates are discarded.
              Search results may contain aspirates in any location.
            - retain: Aspirates in the input must appear in search results.
              Other aspirates may appear in the results.
            - strict: Aspirates in the input must appear in search results.
              Other aspirates may not appear in the results.
            """,
        )
        disjoiner_mode = st.selectbox(
            "Disjoiner Mode",
            ["strict", "retain", "discard"],
            index=0,
            help="""
            How to handle Grascii disjoiners (^).

            - discard: Disjoiners are discarded.
              Search results may contain disjoiners in any location.
            - retain: Disjoiners in the input must appear in search results.
              Other disjoiners may appear in the results.
            - strict: Disjoiners in the input must appear in search results.
              Other disjoiners may not appear in the results.
            """,
        )
    return {
        "interpretation": interpretation,
        "uncertainty": uncertainty,
        "fix_first": fix_first,
        "search_mode": search_mode,
        "annotation_mode": annotation_mode,
        "aspirate_mode": aspirate_mode,
        "disjoiner_mode": disjoiner_mode,
    }


def write_grascii_search() -> None:
    """Render the Grascii search form and the results for the active query.

    The query comes either from a text box or from an uploaded image run
    through the vision model. The active query is read from
    ``st.session_state["grascii"]``; ``st.session_state["alternatives"]``
    holds the image-derived candidates. Both keys must already exist in
    session state when this function runs.
    """
    searcher = GrasciiSearcher()
    grascii_results = []

    search_by = st.radio("Search by", ["text", "image (beta)"], horizontal=True)
    with st.form("Grascii Search"):
        # A single placeholder hosts either the text box or the uploader.
        placeholder = st.empty()
        if search_by == "text":
            placeholder.text_input(
                "Grascii",
                value=st.session_state["grascii"],
                key="grascii_text_box",
                max_chars=MAX_GRASCII_LENGTH,
                help="[Grascii Language Reference](https://grascii.readthedocs.io/en/stable/language.html)",
            )
        else:
            _write_image_input(placeholder)

        options = _write_search_options()
        st.form_submit_button("Search", on_click=on_submit)

    grascii = st.session_state["grascii"]
    # The text box enforces max_chars itself; this check also covers strings
    # that reach session state by other routes (e.g. image predictions).
    if len(grascii) > MAX_GRASCII_LENGTH:
        st.error(f"Grascii too long. Max: {MAX_GRASCII_LENGTH} characters")
        return

    try:
        grascii_results = searcher.sorted_search(grascii=grascii, **options)
    except InvalidGrascii as e:
        # An empty query is not reported as an error; it just shows no results.
        if grascii:
            st.error(f"Invalid Grascii\n```\n{e.context}\n```")
    else:
        # Offer the other image predictions only when there is a choice.
        if len(st.session_state["alternatives"]) > 1:
            st.pills(
                "Alternatives",
                st.session_state["alternatives"],
                key="alternative",
                default=grascii,
                on_change=on_alternative_selection,
            )

    write_results(grascii_results, grascii.upper(), "grascii")


def on_alternative_selection() -> None:
    """Change callback for the "Alternatives" pills.

    Selecting a pill makes it the active Grascii string. If the selection was
    cleared (value ``None``), the current Grascii string is written back so
    that one alternative always stays selected.
    """
    if st.session_state["alternative"] is None:
        st.session_state["alternative"] = st.session_state["grascii"]
    else:
        st.session_state["grascii"] = st.session_state["alternative"]


@st.fragment
def write_results(results, term, key_prefix):
    """Render a result count, a selectable results table and a flag button.

    Args:
        results: Iterable of search results; each item must expose
            ``entry.grascii`` and ``entry.translation``.
        term: The search term echoed in the summary line.
        key_prefix: Prefix for the widget keys so that several result tables
            can coexist on one page.
    """
    rows = [
        [
            result.entry.grascii,
            result.entry.translation,
            image_map.get(result.entry.translation),
        ]
        for result in results
    ]
    data = pd.DataFrame(rows)

    noun = "Results" if len(data) != 1 else "Result"
    st.write(f'{len(data)} {noun} for "{term}"')

    event = st.dataframe(
        data,
        use_container_width=True,
        # The frame is built without column names, so its columns are labelled
        # 0, 1, 2; they are referenced here by their string form.
        column_config={
            "0": "Grascii",
            "1": "Longhand",
            "2": st.column_config.ImageColumn("Shorthand", width="medium"),
        },
        selection_mode="multi-row",
        on_select="rerun",
        key=key_prefix + "_data_frame",
        hide_index=True,
    )
    selected_rows = event.selection.rows

    if st.button(
        "Flag Selected Rows",
        key=key_prefix + "_report_button",
        disabled=len(selected_rows) == 0,
    ):
        report_dialog(data.iloc[selected_rows])


def write_reverse_search() -> None:
    """Render the reverse search form and, for a non-empty query, its results."""
    searcher = ReverseSearcher()
    with st.form("Reverse Search"):
        word = st.text_input("Word(s)")
        st.form_submit_button("Search")

    # Nothing is searched or rendered until the user has entered a word.
    if word:
        reverse_results = searcher.sorted_search(
            reverse=word,
        )
        write_results(reverse_results, word, "reverse")