File size: 9,622 Bytes
5bcc73a
 
 
 
 
 
 
 
 
 
f7fafc2
5bcc73a
 
5db88df
 
 
5bcc73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67bc9b1
5bcc73a
 
67bc9b1
5bcc73a
 
 
 
 
 
 
 
 
 
 
 
67bc9b1
 
 
 
c88ad69
5bcc73a
 
4dee1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7fafc2
 
 
 
5bcc73a
 
 
 
 
 
 
67bc9b1
 
 
 
 
 
f7fafc2
 
5bcc73a
 
 
 
 
 
 
 
 
5f26241
5bcc73a
 
 
 
 
 
 
5f26241
 
 
 
 
 
5bcc73a
 
5f26241
5bcc73a
 
 
 
 
 
 
5f26241
 
 
5bcc73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67bc9b1
5bcc73a
 
 
5db88df
 
 
 
5bcc73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67bc9b1
 
 
 
 
 
 
 
5bcc73a
 
 
67bc9b1
 
 
 
 
 
 
5bcc73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80a676e
5bcc73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import base64
from io import BytesIO
import numpy as np
import streamlit as st
from PIL import Image
import pandas as pd
from datasets import load_dataset
from grascii import GrasciiSearcher, InvalidGrascii, ReverseSearcher
from report import report_dialog
from vision import run_vision
# from save_image import save_image


# Upper bound on the length of a Grascii query string. It is used both as the
# text box's max_chars and as the length check performed before searching.
MAX_GRASCII_LENGTH = 16


@st.cache_data(show_spinner="Loading shorthand images")
def load_images():
    """Build a lookup from longhand word to an inline PNG data URI.

    Loads the "train" split of the "grascii/gregg-preanniversary-words"
    dataset (authenticating with the HF_TOKEN secret), re-encodes every
    row's image as PNG, and base64-encodes it into a ``data:`` URI.

    Returns:
        A dict keyed by each row's "longhand" value whose values are
        ``data:image/png;base64,...`` strings. If several rows share a
        longhand value, the last one encountered is kept.
    """
    dataset = load_dataset(
        "grascii/gregg-preanniversary-words",
        split="train",
        token=st.secrets.HF_TOKEN,
    )
    uri_prefix = "data:image/png;base64,"
    uris_by_longhand = {}
    for record in dataset:
        # Serialize the image to PNG in memory, then base64-encode the bytes.
        with BytesIO() as png_buffer:
            record["image"].save(png_buffer, format="PNG")
            encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
        uris_by_longhand[record["longhand"]] = uri_prefix + encoded
    return uris_by_longhand


# Maps each dataset row's "longhand" value to a PNG data URI of its image.
# Built when the module is executed; write_results() looks entries up by an
# entry's translation to fill the "Shorthand" image column.
image_map = load_images()


def on_submit():
    """Form-submit callback: adopt the text box contents as the active query.

    When the "grascii_text_box" widget value is present in session state,
    copy it to the "grascii" key and reset "alternatives" to an empty dict.
    Does nothing when the text box value is absent.
    """
    state = st.session_state
    if "grascii_text_box" not in state:
        return
    state["grascii"] = state["grascii_text_box"]
    state["alternatives"] = {}


def write_grascii_search():
    """Render the Grascii search form and the results for the current query.

    The user chooses between searching by text or by image:

    - text: a text input (keyed "grascii_text_box") pre-filled with the
      current ``st.session_state["grascii"]`` value.
    - image: an uploaded image is composited onto a white background,
      converted to grayscale, and passed to ``run_vision``. The joined
      predictions are stored in ``st.session_state["alternatives"]`` and the
      first prediction becomes ``st.session_state["grascii"]``.

    After the form, the query in ``st.session_state["grascii"]`` is length
    checked and searched with the options chosen in the "Options" expander.
    An ``InvalidGrascii`` error is reported only for a non-empty query. On
    success, alternative predictions (when more than one exists) are offered
    as pills and the results are rendered with ``write_results``.

    NOTE(review): the "grascii" and "alternatives" session-state keys are read
    here without being initialized in this function; presumably the caller
    sets them up before this runs -- confirm.
    """
    searcher = GrasciiSearcher()

    search_by = st.radio("Search by", ["text", "image (beta)"], horizontal=True)

    with st.form("Grascii Search"):
        placeholder = st.empty()
        if search_by == "text":
            placeholder.text_input(
                "Grascii",
                value=st.session_state["grascii"],
                key="grascii_text_box",
                max_chars=MAX_GRASCII_LENGTH,
                help="[Grascii Language Reference](https://grascii.readthedocs.io/en/stable/language.html)",
            )
        else:
            with placeholder.container():
                image_data = st.file_uploader(
                    "Image",
                    type=["png", "jpg"],
                    help="""
                        Upload an image of a shorthand form.

                        At this time, minimal preprocessing is performed on images
                        before running them through the model. For best results,
                        upload an image:

                        - of a closely cropped, single shorthand form
                        - with the shorthand written in black on a white background
                        - that does not contain marks beside the shorthand form
                        """,
                )
                # Image saving is currently disabled.
                # save = st.checkbox(
                #     "Save images I upload for potential inclusion in open-source datasets used to train and improve models",
                #     key="save_image",
                # )

            if image_data:
                # Flatten any transparency onto a white background before
                # converting to a single-channel grayscale image.
                image = Image.open(image_data).convert("RGBA")
                background = Image.new("RGBA", image.size, (255, 255, 255))
                alpha_composite = Image.alpha_composite(background, image)

                # Wrap the grayscale image in a list so the array has a
                # leading dimension of size one.
                arr = np.array([alpha_composite.convert("L")])
                predictions = run_vision(arr)
                if len(predictions) > 0:
                    # A dict is used so the joined predictions are
                    # de-duplicated while keeping their order.
                    alternatives = {"".join(p): True for p in predictions}
                    # Only reset the query when the set of predictions has
                    # changed, so a previously chosen alternative survives
                    # reruns with the same image.
                    if st.session_state["alternatives"] != alternatives:
                        st.session_state["alternatives"] = alternatives
                        st.session_state["grascii"] = "".join(predictions[0])
                else:
                    # Without this guard, predictions[0] would raise an
                    # IndexError on an empty prediction list.
                    st.warning("No shorthand form could be recognized in the image.")

                # if save:
                #     save_image(image_data.getvalue(), "-".join(predictions[0]))

        with st.expander("Options"):
            interpretation = st.radio(
                "Interpretation",
                ["best", "all"],
                horizontal=True,
                help="""
                    How to interpret ambiguous Grascii strings.

                    - best: Only search using the best interpretation.
                    - all: Search using all possible interpretations.
                    """,
            )
            uncertainty = st.slider(
                "Uncertainty",
                min_value=0,
                max_value=2,
                value=1,
                help="""
                The uncertainty of the strokes in the Grascii string.

                A value of at least 1 is recommended for image searches.
                """,
            )
            fix_first = st.checkbox(
                "Fix First", help="Apply an uncertainty of 0 to the first token."
            )
            search_mode = st.selectbox(
                "Search Mode",
                ["match", "start", "contain"],
                help="""
                    The type of search to perform.

                    - match: Search for entries that closely match the Grascii string.
                    - start: Search for entries that start with the Grascii string.
                    - contain: Search for entries that contain the Grascii string.
                    """,
            )
            annotation_mode = st.selectbox(
                "Annotation Mode",
                ["strict", "retain", "discard"],
                index=2,
                help="""
                    How to handle Grascii annotations.

                    - discard: Annotations are discarded.
                        Search results may contain annotations in any location.
                    - retain: Annotations in the input must appear in search results.
                        Other annotations may appear in the results.
                    - strict: Annotations in the input must appear in search results.
                        Other annotations may not appear in the results.
                    """,
            )
            aspirate_mode = st.selectbox(
                "Aspirate Mode",
                ["strict", "retain", "discard"],
                index=2,
                help="""
                    How to handle Grascii aspirates (').

                    - discard: Aspirates are discarded.
                        Search results may contain aspirates in any location.
                    - retain: Aspirates in the input must appear in search results.
                        Other aspirates may appear in the results.
                    - strict: Aspirates in the input must appear in search results.
                        Other aspirates may not appear in the results.
                    """,
            )
            disjoiner_mode = st.selectbox(
                "Disjoiner Mode",
                ["strict", "retain", "discard"],
                index=0,
                help="""
                    How to handle Grascii disjoiners (^).

                    - discard: Disjoiners are discarded.
                        Search results may contain disjoiners in any location.
                    - retain: Disjoiners in the input must appear in search results.
                        Other disjoiners may appear in the results.
                    - strict: Disjoiners in the input must appear in search results.
                        Other disjoiners may not appear in the results.
                    """,
            )

        st.form_submit_button("Search", on_click=on_submit)

    grascii = st.session_state["grascii"]

    # The text box enforces max_chars, but the query can also come from an
    # image prediction or a selected alternative, so check the length again.
    if len(grascii) > MAX_GRASCII_LENGTH:
        st.error(f"Grascii too long. Max: {MAX_GRASCII_LENGTH} characters")
        return

    try:
        grascii_results = searcher.sorted_search(
            grascii=grascii,
            interpretation=interpretation,
            uncertainty=uncertainty,
            fix_first=fix_first,
            search_mode=search_mode,
            annotation_mode=annotation_mode,
            aspirate_mode=aspirate_mode,
            disjoiner_mode=disjoiner_mode,
        )
    except InvalidGrascii as e:
        # Stay quiet for an empty query; only report errors for real input.
        if grascii:
            st.error(f"Invalid Grascii\n```\n{e.context}\n```")
    else:
        if len(st.session_state["alternatives"]) > 1:
            st.pills(
                "Alternatives",
                st.session_state["alternatives"],
                key="alternative",
                default=grascii,
                on_change=on_alternative_selection,
            )
        write_results(grascii_results, grascii.upper(), "grascii")


def on_alternative_selection():
    """Callback for the "alternative" widget: sync it with the active query.

    If an alternative is selected, it becomes the "grascii" session value.
    If the selection was cleared (value is None), the widget value is put
    back to the current "grascii" value instead.
    """
    state = st.session_state
    if state["alternative"] is not None:
        state["grascii"] = state["alternative"]
    else:
        state["alternative"] = state["grascii"]


@st.fragment
def write_results(results, term, key_prefix):
    """Show search results in a selectable table with a flagging button.

    Args:
        results: Iterable of search results; each must expose
            ``entry.grascii`` and ``entry.translation``.
        term: The search term echoed in the result-count line.
        key_prefix: String prepended to the widget keys of the table and
            the button.
    """
    # One row per result: Grascii, longhand, and the image data URI (None
    # when image_map has no entry for the translation).
    table = pd.DataFrame(
        [
            [
                hit.entry.grascii,
                hit.entry.translation,
                image_map.get(hit.entry.translation),
            ]
            for hit in results
        ]
    )

    count = len(table)
    noun = "Result" if count == 1 else "Results"
    st.write(f'{count} {noun} for "{term}"')

    column_labels = {
        "0": "Grascii",
        "1": "Longhand",
        "2": st.column_config.ImageColumn("Shorthand", width="medium"),
    }
    frame_key = key_prefix + "_data_frame"
    event = st.dataframe(
        table,
        use_container_width=True,
        column_config=column_labels,
        selection_mode="multi-row",
        on_select="rerun",
        key=frame_key,
        hide_index=True,
    )

    picked = event.selection.rows
    nothing_picked = len(picked) == 0
    button_key = key_prefix + "_report_button"
    flag_clicked = st.button(
        "Flag Selected Rows",
        key=button_key,
        disabled=nothing_picked,
    )
    if flag_clicked:
        # Hand only the selected rows to the report dialog.
        report_dialog(table.iloc[picked])


def write_reverse_search():
    """Render the reverse (word to shorthand) search form and its results.

    A form collects one or more words. When the input is non-empty, a
    ``ReverseSearcher`` sorted search is run and the results are rendered
    with ``write_results`` under the "reverse" key prefix. Nothing is
    rendered below the form for empty input.
    """
    searcher = ReverseSearcher()

    with st.form("Reverse Search"):
        word = st.text_input("Word(s)")
        st.form_submit_button("Search")
        # Only query the searcher when there is something to look up.
        reverse_results = searcher.sorted_search(reverse=word) if word else []

    if not word:
        return
    write_results(reverse_results, word, "reverse")