ck46 commited on
Commit
e9d3f14
·
1 Parent(s): 4458a4d

Added Question-Answer generation pipeline

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +21 -6
  3. qg_pipeline.py +143 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.*~
app.py CHANGED
@@ -1,9 +1,20 @@
1
  import streamlit as st
 
 
 
 
 
 
2
 
3
  # Add a model selector to the sidebar
4
- model = st.sidebar.selectbox(
5
- 'Select Model',
6
- ('t5-base-squad-qa-qg', 't5-small-squad-qa-qg', 't5-base-hotpot-qa-qg', 't5-small-hotpot-qa-qg')
 
 
 
 
 
7
  )
8
 
9
  st.header('Question-Answer Generation')
@@ -11,13 +22,17 @@ st.write(f'Model in use: {model}')
11
 
12
  txt = st.text_area('Text for context')
13
 
 
 
 
 
 
 
14
 
15
  if len(txt) >= 1:
16
- autocards = []
17
  else:
18
  autocards = []
19
 
20
  st.header('Generated question and answers')
21
  st.write(autocards)
22
-
23
-
 
1
  import streamlit as st
2
+ from qg_pipeline import Pipeline
3
+
4
+ ## Load NLTK
5
+ import nltk
6
+ nltk.download('punkt')
7
+
8
 
9
  # Add a model selector to the sidebar
10
+ q_model = st.sidebar.selectbox(
11
+ 'Select Question Generation Model',
12
+ ('valhalla/t5-small-qg-hl', 'valhalla/t5-base-qg-hl', 'ck46/t5-base-squad-qa-qg', 'ck46/t5-small-squad-qa-qg', 'ck46/t5-base-hotpot-qa-qg', 'ck46/t5-small-hotpot-qa-qg')
13
+ )
14
+
15
+ a_model = st.sidebar.selectbox(
16
+ 'Select Answer Extraction Model',
17
+ ('valhalla/t5-small-qa-qg-hl', 'valhalla/t5-base-qa-qg-hl', 'ck46/t5-base-squad-qa-qg', 'ck46/t5-small-squad-qa-qg', 'ck46/t5-base-hotpot-qa-qg', 'ck46/t5-small-hotpot-qa-qg')
18
  )
19
 
20
  st.header('Question-Answer Generation')
 
22
 
23
  txt = st.text_area('Text for context')
24
 
25
+ pipeline = Pipeline(
26
+ q_model=q_model,
27
+ q_tokenizer=q_model,
28
+ a_model=a_model,
29
+ q_tokenizer=a_model
30
+ )
31
 
32
  if len(txt) >= 1:
33
+ autocards = pipeline(txt)
34
  else:
35
  autocards = []
36
 
37
  st.header('Generated question and answers')
38
  st.write(autocards)
 
 
qg_pipeline.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import (
3
+ AutoModelForSeq2SeqLM,
4
+ AutoTokenizer,
5
+ PreTrainedModel,
6
+ PreTrainedTokenizer,
7
+ )
8
+ from nltk import sent_tokenize
9
+
10
+ # Answer Extraction Handler
11
+ class AEHandler:
12
+ def __init__(self, model, tokenizer):
13
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
14
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
15
+ self.device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')
16
+ self.model.to(self.device)
17
+
18
+ def __call__(self, context):
19
+ return self.inference(self.preprocess(context))
20
+
21
+ def preprocess(self, context):
22
+ sents = sent_tokenize(context)
23
+
24
+ inputs = []
25
+ for i in range(len(sents)):
26
+ source_text = "extract answers:"
27
+ for j, sent in enumerate(sents):
28
+ if i == j:
29
+ sent = "<hl> %s <hl>" % sent
30
+ source_text = "%s %s" % (source_text, sent)
31
+ source_text = source_text.strip()
32
+ source_text = source_text + " </s>"
33
+ inputs.append(source_text)
34
+
35
+ tokenized_inputs = self.tokenizer.batch_encode_plus(
36
+ inputs,
37
+ max_length=512,
38
+ add_special_tokens=True,
39
+ truncation=True,
40
+ padding="max_length",
41
+ pad_to_max_length=True,
42
+ return_tensors="pt"
43
+ )
44
+ return tokenized_inputs
45
+
46
+ def inference(self, inputs):
47
+ outs = self.model.generate(
48
+ input_ids=inputs['input_ids'].to(self.device),
49
+ attention_mask=inputs['attention_mask'].to(self.device),
50
+ max_length=32)
51
+
52
+ dec = [self.tokenizer.decode(ids, skip_special_tokens=False).replace('<pad> ', '').strip() for ids in outs]
53
+ answers = [item.split('<sep>')[:-1] for item in dec]
54
+ return answers
55
+
56
+ def postprocess(self, outputs):
57
+ return outputs
58
+
59
+
60
+ # Question Generation Handler
61
+ class QGHandler:
62
+ def __init__(self, model, tokenizer):
63
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
64
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
65
+ self.device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')
66
+ self.model.to(device)
67
+
68
+ def __call__(self, answers, context):
69
+ tokenized_inputs = self.preprocess(answers, context)
70
+ return self.inference(tokenized_inputs)
71
+
72
+ def preprocess(self, answers, context):
73
+ # prepare inputs for question generation from answers
74
+ sents = sent_tokenize(context)
75
+ qg_examples = []
76
+ for i, answer in enumerate(answers):
77
+ if len(answer) == 0: continue
78
+ for answer_text in answer:
79
+ sent = sents[i]
80
+ sents_copy = sents[:]
81
+
82
+ answer_text = answer_text.strip()
83
+
84
+ try:
85
+ ans_start_idx = sent.index(answer_text)
86
+ except:
87
+ continue
88
+
89
+ sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text): ]}"
90
+ sents_copy[i] = sent
91
+
92
+ source_text = " ".join(sents_copy)
93
+ source_text = f"generate question: {source_text}"
94
+ #if self.model_type == "t5":
95
+ source_text = source_text + " </s>"
96
+ qg_examples.append({"answer": answer_text, "source_text": source_text})
97
+
98
+ # question generation inputs
99
+ qg_inputs = [example['source_text'] for example in qg_examples]
100
+
101
+ tokenized_inputs = self.tokenizer.batch_encode_plus(
102
+ qg_inputs,
103
+ max_length=512,
104
+ add_special_tokens=True,
105
+ truncation=True,
106
+ padding="max_length",
107
+ pad_to_max_length=True,
108
+ return_tensors="pt"
109
+ )
110
+ self.qg_examples = qg_examples
111
+ return tokenized_inputs
112
+
113
+ def inference(self, inputs):
114
+ outs = self.model.generate(
115
+ input_ids=inputs['input_ids'].to(self.device),
116
+ attention_mask=inputs['attention_mask'].to(self.device),
117
+ max_length=32,
118
+ num_beams=4,
119
+ )
120
+
121
+ questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
122
+ return questions
123
+
124
+ def postprocess(self, questions):
125
+ outputs = [{'question': que, 'answer': example['answer']} for example, que in zip(self.qg_examples, questions)]
126
+ return outputs
127
+
128
+
129
+ # Question-Answer Generation Pipeline
130
+ class Pipeline:
131
+ def __init__(self, q_model=None, q_tokenizer=None, a_model=None, a_tokenizer=None):
132
+ self.q_model = q_model if q_model is not None else "valhalla/t5-small-qg-hl"
133
+ self.q_tokenizer = q_tokenizer if q_tokenizer is not None else "valhalla/t5-small-qg-hl"
134
+ self.a_model = a_model if a_model is not None else "valhalla/t5-small-qa-qg-hl"
135
+ self.a_tokenizer = a_tokenizer if a_tokenizer is not None else "valhalla/t5-small-qa-qg-hl"
136
+
137
+ self.answer_extractor = AEHandler(self.a_model, self.a_tokenizer)
138
+ self.question_generator = QGHandler(self.q_model, self.q_tokenizer)
139
+
140
+ def __call__(self, context):
141
+ answers = self.answer_extractor(context)
142
+ questions = self.question_generator(answers, context)
143
+ return self.question_generator.postprocess(questions)