robinq's picture
use KBLab's swedish-ocr-correction model
4583cb1 verified
raw
history blame
3.8 kB
import streamlit as st
from transformers import AutoTokenizer, T5ForConditionalGeneration
import post_ocr
# Sidebar information (Markdown shown under the "About" header).
# The model link should match the checkpoint passed to from_pretrained() below.
info = '''Welcome to the demo of the [swedish-ocr-correction](https://huggingface.co/KBLab/swedish-ocr-correction) model.
Enter or upload OCR output and the model will attempt to correct it.
:clock2: Slow generation? Try a shorter input.
'''
# Example inputs offered in the "Examples" selectbox, in display order.
# 'Examples' is the placeholder entry: its None value seeds the text area
# with nothing. The texts are deliberately noisy OCR output; do not fix them.
examples = dict([
    ('Examples', None),
    ('Example 1',
     'En Gosse fur plats nu genast ! inetallyrkc, JU 83 Drottninggatan.'),
    ('Example 2',
     '— Storartad gåfva till Göteborgs Museum. Den i HandelstidniDgens g&rdagsnnmmer omtalade hvalfisken, sorn fångats i Frölnndaviken, har i dag af hr brukspatronen James Dickson blifvit inköpt för 1,500 rdr och skänkt till härvarande Museum.'),
    ('Example 3',
     'Sn underlig race att ſtudera, desfa uppſinnare! utropar en Londontidnings fronifôr. Wet ni hur ſtort antalet är af patenter, ſom ſiſtlidet är utfärdades i British Patent Office? Jo, 14,000 ſty>en !! Det kan man ju fkalla en rif rd! Fjorton tuſen uppfinninnar! Herre Gud, hwilfet märkrwoärdigt tidehrvarf wi lefroa i!'),
])
# Model: loaded once via st.cache_resource instead of on every script rerun
@st.cache_resource
def load_model():
    """Return the T5 OCR-correction model from the KBLab checkpoint."""
    checkpoint = 'KBLab/swedish-ocr-correction'
    return T5ForConditionalGeneration.from_pretrained(checkpoint)


model = load_model()
# Tokenizer: cached the same way as the model
@st.cache_resource
def load_tokenizer():
    """Return the tokenizer from the 'google/byt5-small' checkpoint."""
    tokenizer_name = 'google/byt5-small'
    return AutoTokenizer.from_pretrained(tokenizer_name)


tokenizer = load_tokenizer()
# Give the post-processing helper module the loaded model and tokenizer
post_ocr.set_model(model, tokenizer)

# Page title
st.title(':memo: Swedish OCR correction')

# Two input modes: typed/pasted text, and an uploaded file
tab_labels = ["Text input", "From file"]
tab1, tab2 = st.tabs(tab_labels)
# Per-tab caches of the last processed input and its output. Each key gets
# its own fresh dict, created only on the first run of the session.
for state_key in ('inputs', 'outputs'):
    if state_key not in st.session_state:
        st.session_state[state_key] = {'tab1': None, 'tab2': None}
# Sidebar: an "About" header followed by the info text rendered as Markdown
st.sidebar.header('About')
st.sidebar.markdown(info)
def handle_input(input_, id_):
    """Generate (when needed) and display the corrected output for one tab.

    The model is run only when ``input_`` differs from the last input that was
    successfully processed for this tab. The result is cached in
    ``st.session_state`` so later reruns (e.g. flipping the toggle) reuse it.

    Args:
        input_: OCR text to correct. Falsy input (``''``/``None``) is not
            processed and no output is shown for it.
        id_: Tab identifier (``'tab1'`` or ``'tab2'``). It is the key into
            ``st.session_state.inputs`` / ``st.session_state.outputs`` and
            makes the "Show changes" toggle key unique per tab.
    """
    # Put everything output-related in a bordered container
    with st.container(border=True):
        st.caption('Output')

        # Only regenerate when the input differs from the cached one
        if input_ and st.session_state.inputs[id_] != input_:
            with st.spinner('Generating...'):
                output = post_ocr.process(input_)
            # Record the input only after generation succeeded. Otherwise a
            # failure would leave the previous output paired with this input.
            st.session_state.outputs[id_] = output
            st.session_state.inputs[id_] = input_

        # This placeholder lets the output appear above the toggle even
        # though the toggle must be created first to read its value.
        container = st.container()
        st.divider()
        show_changes = st.toggle('Show changes', key=f'show_changes_{id_}')

        with container:
            # Show the cached output only if it belongs to the current input.
            # E.g. after the input is cleared, the old output would be stale.
            if input_ and st.session_state.inputs[id_] == input_:
                output = st.session_state.outputs[id_]
                if output is not None:
                    st.write(post_ocr.diff(input_, output) if show_changes else output)
# Manual entry tab: free text, optionally pre-filled from an example
with tab1:
    text_column, example_column = st.columns([4, 1])

    # The example picker runs first because its choice seeds the text area
    example_title = example_column.selectbox(
        'Examples',
        options=examples,
        label_visibility='collapsed',
    )

    text = text_column.text_area(
        label='Input text',
        value=examples[example_title],
        height=200,
        label_visibility='collapsed',
        placeholder='Enter OCR generated text or choose an example',
    )

    # No output section while the text area has no value
    if text is not None:
        handle_input(text, 'tab1')
# File upload tab: read a .txt file, show its content (editable), then correct it
with tab2:
    uploaded_file = st.file_uploader('Choose a file', type='.txt')
    # Display file content
    if uploaded_file is not None:
        try:
            # 'utf-8-sig' decodes plain UTF-8 identically but also drops a
            # leading byte-order mark, which some editors write.
            file_content = uploaded_file.getvalue().decode('utf-8-sig')
        except UnicodeDecodeError:
            # Non-UTF-8 files (e.g. legacy single-byte encodings) would
            # otherwise surface as an unhandled traceback.
            st.error('Could not read the file: it is not valid UTF-8 text.')
        else:
            text = st.text_area('File content', value=file_content, height=300)
            handle_input(text, 'tab2')