Harmanjotkaur1804 committed on
Commit
bfd6a89
1 Parent(s): fd3567b

Update app.py

Files changed (1)
  1. app.py +39 -27
app.py CHANGED
@@ -1,34 +1,46 @@
  import streamlit as st
  
-
- st.title("Correct Grammar with Transformers 🦄")
  st.write("")
  st.write("Input your text here!")
  
  default_value = "Mike and Anna is skiing"
- sent = st.text_area("Text", default_value, height = 50)
- num_return_sequences = st.sidebar.number_input('Number of Return Sequences', min_value=1, max_value=3, value=1, step=1)
  
- ### Run Model
- from transformers import T5ForConditionalGeneration, T5Tokenizer
- import torch
- torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
- tokenizer = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
- model = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector').to(torch_device)
-
- def correct_grammar(input_text,num_return_sequences=num_return_sequences):
-     batch = tokenizer([input_text],truncation=True,padding='max_length',max_length=len(input_text), return_tensors="pt").to(torch_device)
-     results = model.generate(**batch,max_length=len(input_text),num_beams=2, num_return_sequences=num_return_sequences, temperature=1.5)
-     #answer = tokenizer.batch_decode(results[0], skip_special_tokens=True)
-     return results
-
- ##Prompts
- results = correct_grammar(sent, num_return_sequences)
-
- generated_sequences = []
- for generated_sequence_idx, generated_sequence in enumerate(results):
-     # Decode text
-     text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
-     generated_sequences.append(text)
-
- st.write(generated_sequences)

  import streamlit as st
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
+ import torch
  
+ # Set up the Streamlit app
+ st.title("Correct your Grammar with Transformers")
  st.write("")
  st.write("Input your text here!")
  
+ # Create input text area
  default_value = "Mike and Anna is skiing"
+ sent = st.text_area("Text", default_value, height=50)
  
+ # Create "Check Now" button
+ if st.button("Check Now"):
+     # Run Model
+     torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     tokenizer = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
+     model = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector').to(torch_device)
+
+     def correct_grammar(input_text, num_return_sequences=1):
+         batch = tokenizer([input_text], truncation=True, padding='max_length', max_length=len(input_text), return_tensors="pt").to(torch_device)
+         results = model.generate(**batch, max_length=len(input_text), num_beams=2, num_return_sequences=num_return_sequences, temperature=1.5)
+         return results
+
+     # Prompts
+     results = correct_grammar(sent, num_return_sequences=1)
+
+     # Decode results
+     generated_sequences = []
+     for generated_sequence_idx, generated_sequence in enumerate(results):
+         text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+         generated_sequences.append(text)
+
+     # Check correctness
+     is_correct = sent == generated_sequences[0]
+
+     # Display correctness result
+     if is_correct:
+         st.write("Result: ", generated_sequences[0], " (Correct)", key="result_text", unsafe_allow_html=True)
+     else:
+         st.write("Result: ", generated_sequences[0], " (Wrong)", key="result_text", unsafe_allow_html=True)
+
+     # Display correct grammar sentence in a box
+     st.text("Correct Grammar Sentence:")
+     st.code(generated_sequences[0])
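
For reference, the pipeline this commit wires behind the "Check Now" button can be exercised outside Streamlit. Below is a minimal sketch, not the committed code: the fixed 64-token max_length is an assumption (app.py instead bounds max_length by the character length of the input), and temperature is omitted because generate() ignores it when sampling is off.

# Minimal standalone sketch of the grammar-correction pipeline.
# Assumes the deep-learning-analytics/GrammarCorrector checkpoint is
# reachable on the Hugging Face Hub; the 64-token budget is an assumption.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = T5Tokenizer.from_pretrained("deep-learning-analytics/GrammarCorrector")
model = T5ForConditionalGeneration.from_pretrained("deep-learning-analytics/GrammarCorrector").to(device)

def correct_grammar(text, num_return_sequences=1):
    # Tokenize with a fixed token budget, run beam search, and decode
    # each returned sequence back to a string.
    batch = tokenizer([text], truncation=True, padding="max_length",
                      max_length=64, return_tensors="pt").to(device)
    outputs = model.generate(**batch, max_length=64, num_beams=2,
                             num_return_sequences=num_return_sequences)
    return [tokenizer.decode(seq, skip_special_tokens=True,
                             clean_up_tokenization_spaces=True)
            for seq in outputs]

print(correct_grammar("Mike and Anna is skiing"))  # e.g. a corrected sentence such as "Mike and Anna are skiing"

Note that app.py as committed reloads the tokenizer and model on every button press; wrapping the two from_pretrained calls in a helper decorated with st.cache_resource is one way to load them only once per session. The app itself starts with streamlit run app.py.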