init
- app.py +66 -0
- cleaned_text.py +58 -0
- distractor.py +78 -0
- models.py +7 -0
- pipeline.py +82 -0
- qagenerator.py +71 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,66 @@
import streamlit as st
import random
from pipeline import Pipeline

st.header("Multiple-Choice QA Generation")
st.markdown(
    "I built this project based on this [paper](https://www.sciencedirect.com/science/article/pii/S0957417422014014#s0015), "
    "where the authors created end-to-end generation of multiple-choice questions using Text-to-Text Transfer Transformer (T5) models.\n\n"
    "The research focuses on using Transformer-based language models to automate the generation of multiple-choice questions (MCQs), "
    "with the aim of assisting educators in creating reading comprehension (RC) assessments. "
    "This is relevant and timely: teachers can spend less time on routine work and more time with their students, "
    "building a more engaging face-to-face classroom experience. "
    "The study addresses the creation of multiple-choice questionnaires from three viewpoints: question generation (QG), question answering (QA), and distractor generation (DG). "
    "An end-to-end pipeline for generating multiple-choice questions is proposed, based on a pre-trained T5 language model."
)

st.sidebar.info(
    "Note: The number of questions generated depends on the length of the context. "
    "You may find that the number of QA pairs does not match the number you requested."
)

with st.sidebar:
    # Binding the slider to session state via `key` replaces manual
    # on_change bookkeeping and keeps the value across reruns.
    num_qa = st.slider("Select number of QA questions", min_value=1, max_value=10, value=5, step=1, key="num_qa")

if 'context' not in st.session_state:
    st.session_state.context = ""

st_text_area = st.text_area('Context to generate the QA', value=st.session_state.context, height=500)

def generate_qa():
    st.session_state.context = st_text_area
    mcq_generator = Pipeline()
    st.session_state.generator = mcq_generator.generate_mcqs(st_text_area, st.session_state.num_qa)

# Generate-QA button
st.button('Generate', on_click=generate_qa)

# Display generated MCQs in Streamlit
if 'generator' in st.session_state and len(st.session_state.generator) > 0:
    st.subheader("Generated MCQs")
    for i, question in enumerate(st.session_state.generator, start=1):
        correct_answer = question.answerText
        distractors_subset = question.distractors[:3]  # show at most 3 distractors
        options = [correct_answer] + distractors_subset

        # Shuffle so the correct answer is not always option A
        random.shuffle(options)

        options_with_labels = [{'label': chr(ord('A') + j), 'text': option} for j, option in enumerate(options)]

        st.write(f'Number {i}: {question.questionText}')
        for option in options_with_labels:
            if option["text"] == correct_answer:
                # Highlight the correct answer in green
                st.write(f'<span style="color:green;">{option["label"]}. {option["text"]}</span>', unsafe_allow_html=True)
            else:
                st.write(f'{option["label"]}. {option["text"]}')
        st.write('-------------------')
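As a standalone illustration of the option-labeling and shuffling logic above (the answer and distractors here are made-up values, not real model output):

import random

correct_answer = "Paris"
distractors = ["London", "Berlin", "Madrid"]
options = [correct_answer] + distractors

random.shuffle(options)  # correct answer lands at a random position
labeled = [{'label': chr(ord('A') + j), 'text': opt} for j, opt in enumerate(options)]
for opt in labeled:
    marker = " (correct)" if opt['text'] == correct_answer else ""
    print(f"{opt['label']}. {opt['text']}{marker}")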
cleaned_text.py
ADDED
@@ -0,0 +1,58 @@
import re
import string
from typing import List


def normalize_item(item) -> str:
    """Lowercase text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(item))))


def remove_duplicates(items: List[str]) -> List[str]:
    """Drop items that duplicate an earlier item after normalization."""
    unique_items = []
    normalized_unique_items = []

    for item in items:
        normalized_item = normalize_item(item)
        if normalized_item not in normalized_unique_items:
            unique_items.append(item)
            normalized_unique_items.append(normalized_item)

    return unique_items


def remove_distractors_duplicate_with_correct_answer(correct: str, distractors: List[str]) -> List[str]:
    """Drop distractors that match the correct answer after normalization."""
    normalized_correct = normalize_item(correct)
    return [d for d in distractors if normalize_item(d) != normalized_correct]


def clean_text(text: str) -> str:
    # Remove parenthesized text
    cleaned_text = re.sub(r"\(.*?\)", "", text)
    # Remove square-bracketed text
    cleaned_text = re.sub(r"\[.*?\]", "", cleaned_text)
    # Collapse multiple spaces
    cleaned_text = re.sub(" +", " ", cleaned_text)
    # Replace the en dash with a plain hyphen
    cleaned_text = cleaned_text.replace('–', '-')

    return cleaned_text
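A quick sanity check of these helpers, with expected outputs shown in comments (the inputs are illustrative):

from cleaned_text import normalize_item, remove_duplicates, clean_text

print(normalize_item("The Eiffel Tower!"))              # eiffel tower
print(remove_duplicates(["Paris", "paris.", "Lyon"]))   # ['Paris', 'Lyon']
print(clean_text("T5 (Raffel et al.) is  a model"))     # T5 is a model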
distractor.py
ADDED
@@ -0,0 +1,78 @@
from transformers import T5TokenizerFast, T5ForConditionalGeneration
import string
from typing import List


SOURCE_MAX_TOKEN_LEN = 512
TARGET_MAX_TOKEN_LEN = 50
SEP_TOKEN = "[SEP]"
MODEL_NAME = "t5-small"

# Definition of the DistractorGenerator class
class DistractorGenerator:
    def __init__(self):
        self.tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
        self.tokenizer.add_tokens(SEP_TOKEN)
        self.tokenizer_len = len(self.tokenizer)
        self.model = T5ForConditionalGeneration.from_pretrained("fahmiaziz/QDModel")

    def generate(self, generate_count: int, correct: str, question: str, context: str) -> List[str]:
        model_output = self._model_predict(generate_count, correct, question, context)

        # Strip padding tokens, turn end-of-sequence markers into commas,
        # and normalize any leftover sentinel tokens before splitting.
        cleaned_result = model_output.replace('<pad>', '').replace('</s>', ',')
        cleaned_result = self._replace_all_extra_id(cleaned_result)
        distractors = cleaned_result.split(",")[:-1]
        distractors = [x.translate(str.maketrans('', '', string.punctuation)) for x in distractors]
        distractors = [x.strip() for x in distractors]

        return distractors

    def _model_predict(self, generate_count: int, correct: str, question: str, context: str) -> str:
        source_encoding = self.tokenizer(
            '{} {} {} {} {}'.format(correct, SEP_TOKEN, question, SEP_TOKEN, context),
            max_length=SOURCE_MAX_TOKEN_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        generated_ids = self.model.generate(
            input_ids=source_encoding['input_ids'],
            attention_mask=source_encoding['attention_mask'],
            num_beams=generate_count,
            num_return_sequences=generate_count,
            max_length=TARGET_MAX_TOKEN_LEN,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True
        )

        # A set deduplicates identical beams; note that join order is arbitrary.
        preds = {
            self.tokenizer.decode(generated_id, skip_special_tokens=False, clean_up_tokenization_spaces=True)
            for generated_id in generated_ids
        }

        return ''.join(preds)

    def _correct_index_of(self, text: str, substring: str, start_index: int = 0) -> int:
        """Like str.index, but returns -1 instead of raising ValueError."""
        try:
            return text.index(substring, start_index)
        except ValueError:
            return -1

    def _replace_all_extra_id(self, text: str) -> str:
        """Replace every <extra_id_N> sentinel token with [SEP]."""
        new_text = text
        start_index_of_extra_id = 0

        while self._correct_index_of(new_text, '<extra_id_') >= 0:
            start_index_of_extra_id = self._correct_index_of(new_text, '<extra_id_', start_index_of_extra_id)
            end_index_of_extra_id = self._correct_index_of(new_text, '>', start_index_of_extra_id)
            new_text = new_text[:start_index_of_extra_id] + '[SEP]' + new_text[end_index_of_extra_id + 1:]

        return new_text
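The post-processing in generate() can be traced on a hypothetical raw decode (the string below is an assumed example, not actual model output):

import string

# Hypothetical decode from the distractor model: three beams joined together
raw = "<pad> Paris</s> London</s> Berlin</s>"

cleaned = raw.replace('<pad>', '').replace('</s>', ',')  # " Paris, London, Berlin,"
parts = cleaned.split(",")[:-1]                          # [' Paris', ' London', ' Berlin']
distractors = [p.translate(str.maketrans('', '', string.punctuation)).strip() for p in parts]
print(distractors)  # ['Paris', 'London', 'Berlin']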
models.py
ADDED
@@ -0,0 +1,7 @@
from typing import List, Optional


class Question:
    def __init__(self, answerText: str, questionText: str = '', distractors: Optional[List[str]] = None):
        self.answerText = answerText
        self.questionText = questionText
        # Avoid the mutable-default-argument pitfall: a shared list default
        # would be reused across every Question instance.
        self.distractors = distractors if distractors is not None else []
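A minimal usage sketch of the Question container (all values illustrative):

from models import Question

q = Question("Paris", "What is the capital of France?", ["London", "Berlin", "Madrid"])
print(q.questionText, "->", q.answerText)  # What is the capital of France? -> Paris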
pipeline.py
ADDED
@@ -0,0 +1,82 @@
from typing import List

import nltk
from nltk.tokenize import sent_tokenize
import toolz

from models import Question
from cleaned_text import clean_text, remove_duplicates, remove_distractors_duplicate_with_correct_answer
from distractor import DistractorGenerator
from qagenerator import QuestionAnswerGenerator

# sent_tokenize needs the punkt tokenizer data at runtime
nltk.download('punkt', quiet=True)


class Pipeline:

    def __init__(self):
        self.question_generator = QuestionAnswerGenerator()
        self.distractor_generator = DistractorGenerator()

    # <======================= Main Function =============================>
    def generate_mcqs(self, context: str, desired_count: int) -> List[Question]:
        cleaned_text = clean_text(context)
        questions = self._generate_question_answer_pairs(cleaned_text, desired_count)
        questions = self._generate_distractors(cleaned_text, questions)

        return questions
    # <====================================================>

    # number: 1
    def _generate_question_answer_pairs(self, context: str, desired_count: int) -> List[Question]:
        context_splits = self._split_context_according_to_desired_count(context, desired_count)

        questions = []
        for split in context_splits:
            answer, question = self.question_generator.generate_qna(split)
            questions.append(Question(answer.capitalize(), question))

        # Keep only the first question per unique answer text
        questions = list(toolz.unique(questions, key=lambda x: x.answerText))

        return questions

    # number: 2
    def _generate_distractors(self, context: str, questions: List[Question]) -> List[Question]:
        for question in questions:
            t5_distractors = self.distractor_generator.generate(5, question.answerText, question.questionText, context)

            distractors = remove_duplicates(t5_distractors)
            distractors = remove_distractors_duplicate_with_correct_answer(question.answerText, distractors)

            # TODO - filter distractors having a similar BLEU score to another distractor
            # filter_distractors = []
            # for dist in distractors:
            #     bleu_score = self._calculate_nltk_bleu([dist], question.answerText)
            #     if bleu_score > 0.1:
            #         filter_distractors.append(dist)
            # <=================Need Improve Model=================>

            question.distractors = distractors
        return questions

    # Helper functions
    def _split_context_according_to_desired_count(self, context: str, desired_count: int) -> List[str]:
        sents = sent_tokenize(context)
        total_sents = len(sents)

        if total_sents <= desired_count:
            return sents  # No need to split if there are at most desired_count sentences.

        sentences_per_split = total_sents // desired_count
        remainder = total_sents % desired_count  # Spread the leftover sentences over the first splits.

        context_splits = []
        start_sent_index = 0

        for i in range(desired_count):
            end_sent_index = start_sent_index + sentences_per_split + (1 if i < remainder else 0)
            context_splits.append(' '.join(sents[start_sent_index:end_sent_index]))
            start_sent_index = end_sent_index

        return context_splits
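The splitting arithmetic in _split_context_according_to_desired_count can be checked in isolation; with 7 sentences and desired_count=3 it produces splits of 3, 2 and 2 sentences (dummy sentences below):

sents = [f"S{i}." for i in range(7)]
desired_count = 3

per_split, remainder = divmod(len(sents), desired_count)  # 2, 1
splits, start = [], 0
for i in range(desired_count):
    end = start + per_split + (1 if i < remainder else 0)
    splits.append(' '.join(sents[start:end]))
    start = end

print([len(s.split()) for s in splits])  # [3, 2, 2]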
qagenerator.py
ADDED
@@ -0,0 +1,71 @@
from typing import Tuple

from transformers import T5TokenizerFast, T5ForConditionalGeneration

# Constants
MODEL_NAME = 't5-small'
SOURCE_MAX_TOKEN_LEN = 300
TARGET_MAX_TOKEN_LEN = 80
SEP_TOKEN = '<sep>'
TOKENIZER_LEN = 32101


class QuestionAnswerGenerator:

    def __init__(self):
        self.tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
        self.tokenizer.add_tokens(SEP_TOKEN)
        self.tokenizer_len = len(self.tokenizer)
        self.model = T5ForConditionalGeneration.from_pretrained("fahmiaziz/QAModel")

    def generate(self, answer: str, context: str) -> str:
        model_output = self._model_predict(answer, context)
        _generated_answer, generated_question = model_output.split(SEP_TOKEN)
        return generated_question

    def generate_qna(self, context: str) -> Tuple[str, str]:
        # Ask the model to fill in both the answer and the question
        answer_mask = '[MASK]'
        model_output = self._model_predict(answer_mask, context)

        qna_pair = model_output.split(SEP_TOKEN)

        if len(qna_pair) < 2:
            # No separator in the output; treat the whole decode as the question
            generated_answer = ''
            generated_question = qna_pair[0]
        else:
            generated_answer = qna_pair[0]
            generated_question = qna_pair[1]

        return generated_answer, generated_question

    def _model_predict(self, answer: str, context: str) -> str:
        source_encoding = self.tokenizer(
            '{} {} {}'.format(answer, SEP_TOKEN, context),
            max_length=SOURCE_MAX_TOKEN_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        generated_ids = self.model.generate(
            input_ids=source_encoding['input_ids'],
            attention_mask=source_encoding['attention_mask'],
            num_beams=16,
            max_length=TARGET_MAX_TOKEN_LEN,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True
        )

        preds = {
            self.tokenizer.decode(generated_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for generated_id in generated_ids
        }

        return ''.join(preds)
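The model is prompted as '[MASK] <sep> context' and expected to decode as 'answer <sep> question'; the parsing in generate_qna can be traced on a hypothetical decode (strip() added here for readability):

SEP_TOKEN = '<sep>'

# Hypothetical model output for a context about France
model_output = "Paris <sep> What is the capital of France?"

qna_pair = model_output.split(SEP_TOKEN)
if len(qna_pair) < 2:
    answer, question = '', qna_pair[0]
else:
    answer, question = qna_pair[0].strip(), qna_pair[1].strip()

print(answer, "|", question)  # Paris | What is the capital of France?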
requirements.txt
ADDED
@@ -0,0 +1,3 @@
streamlit
nltk
transformers
torch  # added: backend required by the transformers T5 models (return_tensors='pt')
toolz  # added: imported in pipeline.py
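With the dependencies installed (pip install -r requirements.txt), the app can be launched locally with streamlit run app.py.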