# app.py — Shona spelling-correction demo (Hugging Face Space: thaboe01, commit b797ac8)
import streamlit as st
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Load the fine-tuned FLAN-T5 model and tokenizer.
@st.cache_resource
def load_model():
    """Load the tokenizer and the fine-tuned FLAN-T5 spelling-correction model.

    Decorated with ``st.cache_resource`` so the slow download/initialisation
    runs once per server process rather than on every Streamlit rerun.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    # NOTE(review): the tokenizer is taken from the base "google/flan-t5-small"
    # checkpoint while the weights come from the fine-tuned repo — presumably
    # the fine-tuned repo ships no tokenizer files; confirm the vocabularies match.
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
    model = T5ForConditionalGeneration.from_pretrained("thaboe01/t5-spelling-corrector")
    return tokenizer, model
# Load model (only once)
tokenizer, model = load_model()
MAX_PHRASE_LENGTH = 3  # words per chunk fed to the model in correct_text()
PREFIX = "Please correct the following sentence: "  # instruction prepended to every chunk
# Function to correct text
def correct_text(text):
    """Spell-correct *text* by running it through the model in small chunks.

    The input is split on whitespace and fed to the model a few words at a
    time (controlled by ``MAX_PHRASE_LENGTH``), each chunk prefixed with the
    ``PREFIX`` instruction. Corrected chunks are re-joined with spaces.

    Args:
        text: Raw user text; an empty string yields an empty string.

    Returns:
        str: The model's corrected text.
    """
    def _correct_chunk(chunk):
        # Run a single chunk of words through the model and decode the output.
        input_text = PREFIX + " ".join(chunk)
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        outputs = model.generate(input_ids)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    corrected_phrases = []
    current_chunk = []
    for word in text.split():
        current_chunk.append(word)
        # Flush once the chunk reaches the maximum phrase length
        # (including room for the instruction prefix).
        if len(current_chunk) + 1 > MAX_PHRASE_LENGTH:
            corrected_phrases.append(_correct_chunk(current_chunk))
            current_chunk = []  # Reset the chunk
    # Handle the trailing partial chunk, if any.
    if current_chunk:
        corrected_phrases.append(_correct_chunk(current_chunk))
    return " ".join(corrected_phrases)  # Join the corrected chunks
# Streamlit App
st.title("Shona Text Editor with Real-Time Spelling Correction")
text_input = st.text_area("Start typing here...", height=250)
if text_input:
    # Only invoke the model when the user has typed something.
    corrected_text = correct_text(text_input)
    # Read-only output pane mirroring the input area's size.
    st.text_area("Corrected Text", value=corrected_text, height=250, disabled=True)