|
import streamlit as st |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load and cache the tokenizer/model pair for the lifetime of the app.

    Returns:
        tuple: ``(tokenizer, model)`` — a :class:`T5Tokenizer` and a
        :class:`T5ForConditionalGeneration` loaded from the Hugging Face Hub.

    NOTE(review): the tokenizer is loaded from "google/flan-t5-small" while the
    model weights come from "thaboe01/t5-spelling-corrector" — presumably the
    latter was fine-tuned from the former, but confirm the vocabularies match.
    """
    corrector_model = T5ForConditionalGeneration.from_pretrained(
        "thaboe01/t5-spelling-corrector"
    )
    flan_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
    return flan_tokenizer, corrector_model
|
|
|
|
|
# Load the (cached) tokenizer and model once at module import; every Streamlit
# rerun reuses the same objects thanks to @st.cache_resource on load_model.
tokenizer, model = load_model()


# Maximum number of whitespace-separated words sent to the model per request.
MAX_PHRASE_LENGTH = 3

# Instruction prepended to each word chunk before it is fed to the model.
PREFIX = "Please correct the following sentence: "
|
|
|
|
|
def _correct_chunk(chunk):
    """Run one chunk of words through the model and return the corrected text.

    Args:
        chunk (list[str]): words to correct, joined into a single phrase.

    Returns:
        str: the model's corrected version of the phrase.
    """
    input_text = PREFIX + " ".join(chunk)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    outputs = model.generate(input_ids)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def correct_text(text):
    """Spell-correct *text* by running it through the model in small chunks.

    The text is split on whitespace and processed in groups of
    ``MAX_PHRASE_LENGTH`` words (the final group may be shorter); the corrected
    groups are re-joined with single spaces.

    Args:
        text (str): the raw user input. An empty or whitespace-only string
            yields "".

    Returns:
        str: the corrected text.
    """
    words = text.split()
    # Slice the word list into consecutive chunks of MAX_PHRASE_LENGTH words;
    # this replaces the previous duplicated accumulate-and-flush loop.
    corrected_phrases = [
        _correct_chunk(words[start:start + MAX_PHRASE_LENGTH])
        for start in range(0, len(words), MAX_PHRASE_LENGTH)
    ]
    return " ".join(corrected_phrases)
|
|
|
|
|
|
|
# --- Streamlit UI ---
st.title("Shona Text Editor with Real-Time Spelling Correction")

# Free-form input area; Streamlit reruns the script on every edit commit.
text_input = st.text_area("Start typing here...", height=250)


# Only run the model when the user has actually typed something.
if text_input:

    corrected_text = correct_text(text_input)

    # Read-only output area showing the model's corrected version.
    st.text_area("Corrected Text", value=corrected_text, height=250, disabled=True)