|
import streamlit as st |
|
from transformers import AutoTokenizer, T5ForConditionalGeneration |
|
import post_ocr |
|
|
|
|
|
# Markdown shown in the sidebar's "About" section. The model link must point
# to the same Hugging Face checkpoint that load_model() actually loads
# ('KBLab/swedish-ocr-correction').
info = '''Welcome to the demo of the [swedish-ocr-correction](https://huggingface.co/KBLab/swedish-ocr-correction) model.

Enter or upload OCR output and the model will attempt to correct it.

:clock2: Slow generation? Try a shorter input.
'''


# Options for the example selectbox, mapping display title -> input text.
# 'Examples' is the placeholder entry; its None value leaves the text area
# empty. The example texts are intentionally garbled OCR output and must not
# be "corrected" here.
examples = {
    'Examples': None,
    'Example 1': 'En Gosse fur plats nu genast ! inetallyrkc, JU 83 Drottninggatan.',
    'Example 2': '— Storartad gåfva till Göteborgs Museum. Den i HandelstidniDgens g&rdagsnnmmer omtalade hvalfisken, sorn fångats i Frölnndaviken, har i dag af hr brukspatronen James Dickson blifvit inköpt för 1,500 rdr och skänkt till härvarande Museum.',
    'Example 3': 'Sn underlig race att ſtudera, desfa uppſinnare! utropar en Londontidnings fronifôr. Wet ni hur ſtort antalet är af patenter, ſom ſiſtlidet är utfärdades i British Patent Office? Jo, 14,000 ſty>en !! Det kan man ju fkalla en rif rd! Fjorton tuſen uppfinninnar! Herre Gud, hwilfet märkrwoärdigt tidehrvarf wi lefroa i!'
}
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Return the T5 OCR-correction model.

    Decorated with ``st.cache_resource`` so the loaded model object is reused
    across Streamlit reruns instead of being loaded again each time.
    """
    checkpoint = 'KBLab/swedish-ocr-correction'
    return T5ForConditionalGeneration.from_pretrained(checkpoint)


model = load_model()
|
|
|
|
|
|
|
@st.cache_resource
def load_tokenizer():
    """Return the tokenizer from the 'google/byt5-small' checkpoint.

    It is later handed to ``post_ocr.set_model`` together with the model.
    ``st.cache_resource`` keeps the same object across Streamlit reruns.
    """
    tokenizer_name = 'google/byt5-small'
    return AutoTokenizer.from_pretrained(tokenizer_name)


tokenizer = load_tokenizer()
|
|
|
|
|
|
|
# Give the post-OCR helper module the model/tokenizer it should use.
post_ocr.set_model(model, tokenizer)

st.title(':memo: Swedish OCR correction')

tab1, tab2 = st.tabs(["Text input", "From file"])

# Per-tab record of the last processed input and its corrected output.
# Stored in session state so it persists across Streamlit reruns.
for state_key in ('inputs', 'outputs'):
    if state_key not in st.session_state:
        st.session_state[state_key] = {'tab1': None, 'tab2': None}

# Sidebar "About" panel (object notation instead of a `with` block).
st.sidebar.header('About')
st.sidebar.markdown(info)
|
|
|
|
|
def handle_input(input_, id_):
    """Generate (when needed) and display the corrected output for one tab.

    The last processed input and its output are kept per tab in
    ``st.session_state.inputs`` / ``st.session_state.outputs``. On a rerun
    with unchanged input (for example, after flipping the "Show changes"
    toggle), the model is therefore not called again.

    Args:
        input_: The OCR text to correct, as returned by the tab's widget.
        id_: Tab identifier ('tab1' or 'tab2'). Used as the key into the
            session-state dicts and to give each tab's toggle a unique key.
    """
    with st.container(border=True):
        st.caption('Output')

        if input_ and st.session_state.inputs[id_] != input_:
            with st.spinner('Generating...'):
                output = post_ocr.process(input_)
            # Record the input only after generation succeeded. If
            # post_ocr.process() raises, the stored input/output pair is left
            # untouched, rather than recording the new input next to an output
            # generated from different text.
            st.session_state.outputs[id_] = output
            st.session_state.inputs[id_] = input_

        # Reserve a slot above the toggle. The output renders above it but
        # can still depend on the toggle's value, which is read afterwards.
        container = st.container()
        st.divider()
        show_changes = st.toggle('Show changes', key=f'show_changes_{id_}')

        with container:
            output = st.session_state.outputs[id_]
            # Show the stored output only when it was generated from the text
            # that is currently entered. Otherwise (e.g. the input was
            # cleared) it is stale, and diffing it against input_ would be
            # meaningless.
            if output is not None and st.session_state.inputs[id_] == input_:
                st.write(post_ocr.diff(input_, output) if show_changes else output)
|
|
|
|
|
|
|
# Tab 1: free-text input, optionally pre-filled from one of the examples.
with tab1:
    col1, col2 = st.columns([4, 1])

    with col2:
        example_title = st.selectbox('Examples', options=examples,
                                     label_visibility='collapsed')

    with col1:
        text = st.text_area(
            label='Input text',
            value=examples[example_title],
            height=200,
            label_visibility='collapsed',
            placeholder='Enter OCR generated text or choose an example')

    # The text area can return None while it is empty (placeholder example
    # selected and nothing typed yet); only process real input.
    if text is not None:
        handle_input(text, 'tab1')


# Tab 2: upload a .txt file. Its content is shown in an editable text area
# and corrected.
with tab2:
    uploaded_file = st.file_uploader('Choose a file', type='.txt')

    if uploaded_file is not None:
        try:
            # 'utf-8-sig' decodes UTF-8 and drops a leading byte-order mark,
            # so the BOM is not fed to the model as text.
            file_content = uploaded_file.getvalue().decode('utf-8-sig')
        except UnicodeDecodeError:
            # Report non-UTF-8 uploads (e.g. Latin-1) instead of crashing
            # the script with a traceback.
            st.error('Could not read the file: it is not valid UTF-8 text.')
        else:
            text = st.text_area('File content', value=file_content, height=300)
            handle_input(text, 'tab2')
|
|