fahmiaziz commited on
Commit
aeae383
·
verified ·
1 Parent(s): bc2c5d1
Files changed (7) hide show
  1. app.py +66 -0
  2. cleaned_text.py +58 -0
  3. distractor.py +78 -0
  4. models.py +7 -0
  5. pipeline.py +82 -0
  6. qagenerator.py +71 -0
  7. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import random

import streamlit as st

from pipeline import Pipeline

st.header("Generate Multiple Choice QA Generation")
st.markdown(
    "I built this project based on this [paper](https://www.sciencedirect.com/science/article/pii/S0957417422014014#s0015), "
    "where they created End-to-End generation of Multiple-Choice questions using Text-to-Text transfer Transformer models (T5).\n\n"
    "This research focuses on using Transformer-based language models to automate the generation of multiple-choice questions (MCQs), "
    "with the aim of aiding or assisting educators in the process of creating reading comprehension (RC) assessments. "
    "This is relevant and timely as teachers can invest less time doing routine work and share more time with their students, "
    "thus building an engaging experience for face-to-face classroom interaction. "
    "This study addresses the issue of creating multiple-choice questionnaires from 3 viewpoints: QG, QA, and distractor generation (DG). "
    "An end-to-end pipeline for generating multiple-choice questions is proposed, based on a pre-trained T5 language model."
)


st.sidebar.info(
    "Note: The number of questions generated depends on the length of the context. "
    "You may find that the number of QA pairs does not match the number you want."
)

with st.sidebar:
    # BUGFIX: the original registered an on_change callback that assigned the
    # module-level name `num_qa` into session state. Streamlit runs callbacks
    # BEFORE the script body on a rerun, so `num_qa` did not exist yet and the
    # callback raised NameError. Binding the widget to session state with
    # `key=` makes the callback (and the manual session-state seeding of 5,
    # which contradicted the slider default of 1) unnecessary.
    num_qa = st.slider("Select Number of QA questions", min_value=1, max_value=10, value=1, step=1, key="num_qa")

if 'context' not in st.session_state:
    st.session_state.context = ""

st_text_area = st.text_area('Context to generate the QA', value=st.session_state.context, height=500)


@st.cache_resource
def load_pipeline() -> Pipeline:
    """Load the pipeline (two fine-tuned T5 models) once and reuse it across reruns.

    The original constructed a fresh Pipeline on every button click, re-downloading
    and re-loading both models each time.
    """
    return Pipeline()


def generate_qa():
    """Run the MCQ pipeline on the current context and stash results in session state."""
    st.session_state.context = st_text_area
    st.session_state.generator = load_pipeline().generate_mcqs(st_text_area, num_qa)


# generate qa button
st.button('Generate', on_click=generate_qa)

# Display generated MCQs in Streamlit
if len(st.session_state.get('generator', [])) > 0:
    st.subheader("Generated MCQs")
    for i, question in enumerate(st.session_state.generator, start=1):
        correct_answer = question.answerText
        options = [correct_answer] + question.distractors[:3]  # correct answer + up to 3 distractors

        # Shuffle options so the correct answer isn't always option A.
        random.shuffle(options)

        st.write(f'Number {i}: {question.questionText}')
        for j, option in enumerate(options):
            label = chr(ord('A') + j)
            if option == correct_answer:
                # Highlight the correct option in green.
                st.write(f'<span style="color:green;">{label}. {option}</span>', unsafe_allow_html=True)
            else:
                st.write(f'{label}. {option}')
        st.write('-------------------')
cleaned_text.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import string
3
+ from typing import List
4
+
5
+
def normalize_item(item) -> str:
    """Normalize a string for duplicate comparison.

    Lower-cases the text, strips all punctuation, removes the English
    articles 'a'/'an'/'the', and collapses runs of whitespace.
    """
    text = item.lower()
    # Strip every punctuation character in one C-level pass.
    text = text.translate(str.maketrans('', '', string.punctuation))
    # Articles are replaced by a space; the final join collapses the gaps.
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())
22
+
23
+
def remove_duplicates(items: List[str]) -> List[str]:
    """Return *items* with duplicates removed, comparing normalized forms.

    The first occurrence of each normalized form is kept, preserving the
    original order of *items*.
    """
    # Set membership is O(1); the original kept the normalized forms in a
    # list, giving the loop accidental O(n^2) behavior.
    seen = set()
    unique_items = []

    for item in items:
        normalized_item = normalize_item(item)
        if normalized_item not in seen:
            seen.add(normalized_item)
            unique_items.append(item)

    return unique_items
36
+
def remove_distractors_duplicate_with_correct_answer(correct: str, distractors: List[str]) -> List[str]:
    """Drop every distractor whose normalized form equals the normalized correct answer."""
    target = normalize_item(correct)
    return [d for d in distractors if normalize_item(d) != target]
47
+
def clean_text(text: str) -> str:
    """Clean a context string before question generation.

    Removes parenthesised and square-bracketed asides, collapses repeated
    spaces, and replaces en-dashes with plain hyphens.
    """
    # The original passed `lambda L: ""` as the replacement; a plain empty
    # string does the same without the needless callable indirection.
    # remove brackets
    cleaned_text = re.sub(r"\((.*?)\)", "", text)
    # remove square bracket
    cleaned_text = re.sub(r"\[(.*?)\]", "", cleaned_text)
    # remove multiple space
    cleaned_text = re.sub(" +", " ", cleaned_text)
    # replace weird hyphen (en-dash) with a regular one
    cleaned_text = cleaned_text.replace('–', '-')

    return cleaned_text
distractor.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import T5TokenizerFast, T5ForConditionalGeneration
2
+ import string
3
+ from typing import List
4
+
5
+
SOURCE_MAX_TOKEN_LEN = 512
TARGET_MAX_TOKEN_LEN = 50
SEP_TOKEN = "[SEP]"
MODEL_NAME = "t5-small"

# Distractor generator: wraps a fine-tuned T5 model that proposes wrong-answer
# options for a (correct answer, question, context) triple.
class DistractorGenerator:
    def __init__(self):
        """Load the t5-small tokenizer (extended with [SEP]) and the fine-tuned model."""
        self.tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
        self.tokenizer.add_tokens(SEP_TOKEN)
        self.tokenizer_len = len(self.tokenizer)
        self.model = T5ForConditionalGeneration.from_pretrained("fahmiaziz/QDModel")

    def generate(self, generate_count: int, correct: str, question: str, context: str) -> List[str]:
        """Return up to *generate_count* distractor strings for the given QA pair."""
        model_output = self._model_predict(generate_count, correct, question, context)

        # The raw decode keeps special tokens: turn end-of-sequence markers
        # into commas and sentinel tokens into [SEP] before splitting.
        cleaned_result = model_output.replace('<pad>', '').replace('</s>', ',')
        cleaned_result = self._replace_all_extra_id(cleaned_result)
        # Each sequence ends in '</s>' -> ',', so the final split chunk is empty; drop it.
        distractors = cleaned_result.split(",")[:-1]
        distractors = [x.translate(str.maketrans('', '', string.punctuation)) for x in distractors]
        distractors = [x.strip() for x in distractors]

        return distractors

    def _model_predict(self, generate_count: int, correct: str, question: str, context: str) -> str:
        """Run beam search and return the decoded sequences concatenated into one string."""
        source_encoding = self.tokenizer(
            '{} {} {} {} {}'.format(correct, SEP_TOKEN, question, SEP_TOKEN, context),
            max_length=SOURCE_MAX_TOKEN_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        generated_ids = self.model.generate(
            input_ids=source_encoding['input_ids'],
            attention_mask=source_encoding['attention_mask'],
            num_beams=generate_count,
            num_return_sequences=generate_count,
            max_length=TARGET_MAX_TOKEN_LEN,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True
        )

        # BUGFIX: this was a set comprehension, which deduplicated the returned
        # beams and joined them in hash order (non-deterministic across runs).
        # A list preserves every sequence in the order generate() returned them.
        preds = [
            self.tokenizer.decode(generated_id, skip_special_tokens=False, clean_up_tokenization_spaces=True)
            for generated_id in generated_ids
        ]

        return ''.join(preds)

    def _correct_index_of(self, text: str, substring: str, start_index: int = 0) -> int:
        """Like str.index, but return -1 instead of raising when *substring* is absent."""
        try:
            return text.index(substring, start_index)
        except ValueError:
            return -1

    def _replace_all_extra_id(self, text: str) -> str:
        """Replace every T5 '<extra_id_N>' sentinel token in *text* with '[SEP]'."""
        new_text = text
        start_index_of_extra_id = 0

        while self._correct_index_of(new_text, '<extra_id_') >= 0:
            start_index_of_extra_id = self._correct_index_of(new_text, '<extra_id_', start_index_of_extra_id)
            end_index_of_extra_id = self._correct_index_of(new_text, '>', start_index_of_extra_id)
            new_text = new_text[:start_index_of_extra_id] + '[SEP]' + new_text[end_index_of_extra_id + 1:]

        return new_text
models.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
class Question:
    """Container for one generated MCQ: answer text, question text, and distractors."""

    def __init__(self, answerText: str, questionText: str = '', distractors: List[str] = None):
        self.answerText = answerText
        self.questionText = questionText
        # BUGFIX: the default was a mutable `[]`, shared by every Question
        # created without explicit distractors — appending to one instance's
        # list mutated them all. Each instance now gets its own fresh list.
        self.distractors = distractors if distractors is not None else []
pipeline.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from nltk.tokenize import sent_tokenize
3
+ import toolz
4
+
5
+ from models import Question
6
+ from cleaned_text import clean_text, remove_duplicates, remove_distractors_duplicate_with_correct_answer
7
+ from distractor import DistractorGenerator
8
+ from qagenerator import QuestionAnswerGenerator
9
+
10
+
class Pipeline:
    """End-to-end MCQ pipeline: QA-pair generation followed by distractor generation."""

    def __init__(self):
        # Each generator loads its own fine-tuned T5 checkpoint.
        self.question_generator = QuestionAnswerGenerator()
        self.distractor_generator = DistractorGenerator()

    # <======================= Main Function =============================>
    def generate_mcqs(self, context: str, desired_count: int) -> List[Question]:
        """Produce up to *desired_count* Question objects (with distractors) from *context*."""
        cleaned = clean_text(context)
        qa_pairs = self._generate_question_answer_pairs(cleaned, desired_count)
        return self._generate_distractors(cleaned, qa_pairs)
    # <====================================================>

    # number: 1
    def _generate_question_answer_pairs(self, context: str, desired_count: int) -> List[Question]:
        """Split the context and generate one answer/question pair per split."""
        splits = self._split_context_according_to_desired_count(context, desired_count)

        questions = [
            Question(answer.capitalize(), question_text)
            for answer, question_text in (self.question_generator.generate_qna(s) for s in splits)
        ]

        # Keep only the first question produced for each distinct answer text.
        return list(toolz.unique(questions, key=lambda q: q.answerText))

    # number: 2
    def _generate_distractors(self, context: str, questions: List[Question]) -> List[Question]:
        """Attach cleaned distractors to each question (mutates and returns *questions*)."""
        for question in questions:
            raw = self.distractor_generator.generate(5, question.answerText, question.questionText, context)

            candidates = remove_duplicates(raw)
            candidates = remove_distractors_duplicate_with_correct_answer(question.answerText, candidates)

            # TODO - filter distractors having a similar bleu score with another distractor
            # <=================Need Improve Model=================>

            question.distractors = candidates
        return questions

    # Helper functions
    def _split_context_according_to_desired_count(self, context: str, desired_count: int) -> List[str]:
        """Partition the context's sentences into *desired_count* roughly equal chunks."""
        sents = sent_tokenize(context)
        total_sents = len(sents)

        # Fewer (or equal) sentences than requested chunks: one sentence per chunk.
        if total_sents <= desired_count:
            return sents

        base, remainder = divmod(total_sents, desired_count)

        context_splits = []
        cursor = 0
        for i in range(desired_count):
            # The first `remainder` chunks absorb one extra sentence each.
            step = base + (1 if i < remainder else 0)
            context_splits.append(' '.join(sents[cursor:cursor + step]))
            cursor += step

        return context_splits
qagenerator.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Tuple
2
+ from transformers import T5TokenizerFast, T5ForConditionalGeneration
3
+ import string
4
+ from typing import List
5
+
# Constants
MODEL_NAME = 't5-small'
SOURCE_MAX_TOKEN_LEN = 300
TARGET_MAX_TOKEN_LEN = 80
SEP_TOKEN = '<sep>'
TOKENIZER_LEN = 32101


# Wraps a fine-tuned T5 model that emits '<answer> <sep> <question>' strings.
class QuestionAnswerGenerator():

    def __init__(self):
        """Load the t5-small tokenizer (extended with <sep>) and the fine-tuned QA model."""
        self.tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
        self.tokenizer.add_tokens(SEP_TOKEN)
        self.tokenizer_len = len(self.tokenizer)
        self.model = T5ForConditionalGeneration.from_pretrained("fahmiaziz/QAModel")

    def generate(self, answer: str, context: str) -> str:
        """Generate a question whose answer is *answer*, given *context*."""
        model_output = self._model_predict(answer, context)

        # BUGFIX: the original strict two-way unpack
        # (`a, q = model_output.split(SEP_TOKEN)`) raised ValueError whenever
        # the model emitted no separator or more than one. Split at most once
        # and fall back to the raw output when no separator is present.
        parts = model_output.split(SEP_TOKEN, 1)
        return parts[1] if len(parts) > 1 else parts[0]

    def generate_qna(self, context: str) -> Tuple[str, str]:
        """Generate an (answer, question) pair from *context* using a [MASK] answer slot."""
        answer_mask = '[MASK]'
        model_output = self._model_predict(answer_mask, context)

        # Split only on the first separator so a question that itself contains
        # the separator text is not truncated (the original kept qna_pair[1] only).
        qna_pair = model_output.split(SEP_TOKEN, 1)

        if len(qna_pair) < 2:
            # Model produced no separator: treat the whole output as the question.
            generated_answer = ''
            generated_question = qna_pair[0]
        else:
            generated_answer = qna_pair[0]
            generated_question = qna_pair[1]

        return generated_answer, generated_question

    def _model_predict(self, answer: str, context: str) -> str:
        """Run beam search on '<answer> <sep> <context>' and return the decoded text."""
        source_encoding = self.tokenizer(
            '{} {} {}'.format(answer, SEP_TOKEN, context),
            max_length=SOURCE_MAX_TOKEN_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        generated_ids = self.model.generate(
            input_ids=source_encoding['input_ids'],
            attention_mask=source_encoding['attention_mask'],
            num_beams=16,
            max_length=TARGET_MAX_TOKEN_LEN,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True
        )

        # BUGFIX: this was a set comprehension — it deduplicated the decoded
        # beams and joined them in hash order. A list preserves generation
        # order deterministically.
        preds = [
            self.tokenizer.decode(generated_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for generated_id in generated_ids
        ]

        return ''.join(preds)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
streamlit
nltk
transformers
torch
sentencepiece
toolz