"""Streamlit app that checks and corrects English grammar with a T5 model.

Run with ``streamlit run <this file>``. When "Check Now" is pressed, the text
from the text area is rewritten by the ``deep-learning-analytics/GrammarCorrector``
model. The rewrite is shown with a "(Correct)" / "(Wrong)" verdict that reports
whether the model changed the input (ignoring whitespace-only differences).
"""

import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

MODEL_NAME = "deep-learning-analytics/GrammarCorrector"

# Hard cap on tokens fed to and produced by the model. The previous code used
# len(input_text) -- a *character* count -- as the token limit. That could cut
# off the end-of-sequence token on very short inputs and is 0 for empty input.
MAX_TOKENS = 512
# Minimum generation budget, so short inputs can still grow a little
# (e.g. an inserted article or auxiliary verb).
MIN_OUTPUT_TOKENS = 32


@st.cache_resource
def _load_model():
    """Load the tokenizer and model once per server process and reuse them.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model weights would be reloaded on every button press.

    Returns:
        A ``(tokenizer, model, device)`` tuple; ``model`` is already on ``device``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
    model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
    return tokenizer, model, device


def correct_grammar(input_text, num_return_sequences=1):
    """Return the model's grammar-corrected rewrite(s) of *input_text*.

    Args:
        input_text: The text to correct. It is truncated to ``MAX_TOKENS`` tokens.
        num_return_sequences: How many candidate rewrites to return.

    Returns:
        A list of ``num_return_sequences`` decoded strings, best beam first.
    """
    tokenizer, model, device = _load_model()
    # A single sequence needs no padding; truncation bounds the encoder input.
    batch = tokenizer(
        [input_text],
        truncation=True,
        max_length=MAX_TOKENS,
        return_tensors="pt",
    ).to(device)
    input_len = batch["input_ids"].shape[1]
    # Give the rewrite headroom relative to the input, within the token cap.
    output_len = min(MAX_TOKENS, max(MIN_OUTPUT_TOKENS, 2 * input_len))
    # temperature is not passed: beam search without do_sample ignores it.
    # num_beams must be >= num_return_sequences, or generate() raises.
    results = model.generate(
        **batch,
        max_length=output_len,
        num_beams=max(2, num_return_sequences),
        num_return_sequences=num_return_sequences,
    )
    return tokenizer.batch_decode(
        results,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )


def _normalize_whitespace(text):
    """Trim *text* and collapse internal whitespace runs to single spaces.

    Used so that formatting-only differences (a trailing newline from the
    text area, double spaces the decoder does not reproduce) are not
    reported as grammar errors.
    """
    return " ".join(text.split())


# Set up the Streamlit app
st.title("Correct your Grammar with Transformers")
st.write("")
st.write("Input your text here!")

# Create input text area. Recent Streamlit versions reject st.text_area
# heights below 68 px, so the former 50 px is raised to that minimum.
default_value = "Mike and Anna is skiing"
sent = st.text_area("Text", default_value, height=68)

# Create "Check Now" button
if st.button("Check Now"):
    if not sent.strip():
        # Nothing to correct; avoid sending empty input to the model.
        st.warning("Please enter some text to check.")
    else:
        corrected = correct_grammar(sent, num_return_sequences=1)[0]

        # Check correctness: the input is "correct" if the model left it unchanged.
        is_correct = _normalize_whitespace(sent) == _normalize_whitespace(corrected)
        verdict = " (Correct)" if is_correct else " (Wrong)"

        # Display correctness result. The model output echoes user input, so
        # it is not rendered with unsafe_allow_html.
        st.write("Result: ", corrected, verdict)

        # Display correct grammar sentence in a box
        st.text("Correct Grammar Sentence:")
        st.code(corrected)