Harmanjotkaur1804 commited on
Commit
61dfbf7
1 Parent(s): bfd6a89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -37
app.py CHANGED
@@ -1,46 +1,46 @@
1
  import streamlit as st
2
- from transformers import T5ForConditionalGeneration, T5Tokenizer
3
- import torch
4
 
5
- # Set up the Streamlit app
6
- st.title("Correct your Grammar with Transformers")
7
  st.write("")
8
  st.write("Input your text here!")
9
 
10
- # Create input text area
11
  default_value = "Mike and Anna is skiing"
12
  sent = st.text_area("Text", default_value, height=50)
 
 
 
 
 
 
 
 
 
13
 
14
- # Create "Check Now" button
 
 
 
 
 
 
 
 
 
 
 
15
  if st.button("Check Now"):
16
- # Run Model
17
- torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
- tokenizer = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
19
- model = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector').to(torch_device)
20
-
21
- def correct_grammar(input_text, num_return_sequences=1):
22
- batch = tokenizer([input_text], truncation=True, padding='max_length', max_length=len(input_text), return_tensors="pt").to(torch_device)
23
- results = model.generate(**batch, max_length=len(input_text), num_beams=2, num_return_sequences=num_return_sequences, temperature=1.5)
24
- return results
25
-
26
- # Prompts
27
- results = correct_grammar(sent, num_return_sequences=1)
28
-
29
- # Decode results
30
- generated_sequences = []
31
- for generated_sequence_idx, generated_sequence in enumerate(results):
32
- text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
33
- generated_sequences.append(text)
34
-
35
- # Check correctness
36
- is_correct = sent == generated_sequences[0]
37
-
38
- # Display correctness result
39
- if is_correct:
40
- st.write("Result: ", generated_sequences[0], " (Correct)", key="result_text", unsafe_allow_html=True)
41
- else:
42
- st.write("Result: ", generated_sequences[0], " (Wrong)", key="result_text", unsafe_allow_html=True)
43
-
44
- # Display correct grammar sentence in a box
45
- st.text("Correct Grammar Sentence:")
46
- st.code(generated_sequences[0])
 
1
  import streamlit as st
 
 
2
 
3
+ st.title("Correct Grammar with Transformers ")
 
4
  st.write("")
5
  st.write("Input your text here!")
6
 
 
7
  default_value = "Mike and Anna is skiing"
8
  sent = st.text_area("Text", default_value, height=50)
9
+ num_return_sequences = st.sidebar.number_input('Number of Return Sequences', min_value=1, max_value=3, value=1, step=1)
10
+
11
+ # Run Model
12
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
13
+ import torch
14
+
15
+ torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
+ tokenizer = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
17
+ model = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector').to(torch_device)
18
 
19
+ def correct_grammar(input_text, num_return_sequences=num_return_sequences):
20
+ batch = tokenizer([input_text], truncation=True, padding='max_length', max_length=len(input_text), return_tensors="pt").to(torch_device)
21
+ results = model.generate(**batch, max_length=len(input_text), num_beams=2, num_return_sequences=num_return_sequences, temperature=1.5)
22
+ return results
23
+
24
+ # Prompts
25
+ results = correct_grammar(sent, num_return_sequences)
26
+
27
+ # Decode generated sequences
28
+ generated_sequences = [tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True) for generated_sequence in results]
29
+
30
+ # Add "Check Now" button
31
  if st.button("Check Now"):
32
+ st.write("### Results:")
33
+
34
+ # Check correctness and display in green or red
35
+ for generated_sequence in generated_sequences:
36
+ is_correct = generated_sequence == sent
37
+ color = "green" if is_correct else "red"
38
+ st.write(f"**Generated Sentence:**", generated_sequence, f" (Correct: {is_correct})", unsafe_allow_html=True)
39
+
40
+ # If incorrect, display correct grammar sentence in a box
41
+ if not is_correct:
42
+ st.warning(f"**Correct Grammar:** {sent}")
43
+
44
+ # Display original input
45
+ st.write("### Original Input:")
46
+ st.write(sent)