File size: 1,912 Bytes
a61e741
 
 
 
 
 
 
0180518
a61e741
 
 
 
 
b797ac8
a61e741
 
 
 
 
 
 
 
 
 
 
 
 
 
4742645
1085827
a61e741
 
 
 
 
 
 
4742645
1085827
a61e741
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import streamlit as st
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Build the tokenizer/model pair used for spelling correction.
@st.cache_resource
def load_model():
    """Load and return the tokenizer and the spelling-correction model.

    The function is wrapped in ``st.cache_resource`` so the returned objects
    are cached by Streamlit instead of being rebuilt on each call.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): the tokenizer comes from the base "google/flan-t5-small"
    # checkpoint while the weights come from the fine-tuned repository --
    # confirm both share the same vocabulary.
    tok = T5Tokenizer.from_pretrained("google/flan-t5-small")
    corrector = T5ForConditionalGeneration.from_pretrained(
        "thaboe01/t5-spelling-corrector"
    )
    return tok, corrector

# Instruction text prepended to every chunk before it is tokenized.
PREFIX = "Please correct the following sentence: "
# Word-count threshold used by correct_text when grouping words into chunks.
MAX_PHRASE_LENGTH = 3

# Obtain the shared tokenizer and model (load_model is resource-cached).
tokenizer, model = load_model()

# Spelling correction helpers
def _correct_chunk(chunk_words):
    """Run the model on one group of words and return the decoded output.

    Args:
        chunk_words: Words forming a single phrase. They are joined with
            single spaces and prefixed with ``PREFIX`` before tokenization.

    Returns:
        The generated text for this chunk with special tokens removed.
    """
    prompt = PREFIX + " ".join(chunk_words)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # No generation arguments are passed, so the model's default generation
    # settings apply.
    outputs = model.generate(input_ids)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def correct_text(text):
    """Correct ``text`` by sending it to the model in small word chunks.

    The input is split on whitespace, grouped into consecutive chunks of
    ``MAX_PHRASE_LENGTH`` words (the final chunk may be shorter), and each
    chunk is corrected independently. The corrected chunks are joined with
    single spaces, so the original whitespace layout (including line breaks)
    is not preserved.

    Args:
        text: The raw text to correct.

    Returns:
        The corrected text, or an empty string when ``text`` contains no
        words.
    """
    words = text.split()
    # A chunk always holds at least one word, even if the configured length
    # is zero or negative; this mirrors the original flush condition.
    chunk_size = max(1, MAX_PHRASE_LENGTH)
    chunks = (
        words[start:start + chunk_size]
        for start in range(0, len(words), chunk_size)
    )
    return " ".join(_correct_chunk(chunk) for chunk in chunks)


# --- Streamlit UI ---
st.title("Shona Text Editor with Real-Time Spelling Correction")
text_input = st.text_area(
    "Start typing here...",
    height=250,
)

# Run the corrector only when the input box is non-empty, and show the result
# in a second, read-only text area.
if text_input:
    corrected_text = correct_text(text_input)
    st.text_area(
        "Corrected Text",
        value=corrected_text,
        height=250,
        disabled=True,
    )