flynn-chen commited on
Commit
97ec4dd
·
1 Parent(s): 025d2c7
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: FeinbergQuizNotes
3
+ emoji: 💻
4
+ colorFrom: gray
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.0.26
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+ from question_generation.pipelines import pipeline
5
+ import docx2txt
6
+
7
# Question/answer pairs produced by the most recent process_file() call.
# Kept at module level because reveal_answer() reads and resets it.
qa_list = []


def process_file(Notes):
    """Generate numbered quiz questions and answers from an uploaded Word file.

    Only the text between a line containing "Outline" and the next line
    containing "Summary Learning Points" or "Learning Activity" is fed to the
    question-generation pipeline.

    Args:
        Notes: Uploaded file object; only its ``name`` attribute (the path
            handed to ``docx2txt.process``) is used.

    Returns:
        A two-element list ``[questions, answers]``; each element is a
        newline-separated string of 1-based numbered entries.
    """
    # NOTE(review): these commands run on every call, which is slow; the
    # bundled question_generation README lists transformers==3.0.0 and the
    # punkt download as requirements, so they are kept here unchanged.
    os.system("pip install -U transformers==3.0.0")
    os.system("python -m nltk.downloader punkt")
    nlp = pipeline("question-generation", model="valhalla/t5-small-qg-prepend", qg_format="prepend")

    raw_word_file = docx2txt.process(Notes.name)

    # Remove empty lines.
    non_empty_lines = [line for line in raw_word_file.splitlines() if line != ""]

    # Grab the content section; the marker lines themselves are never kept.
    content_lines = []
    in_content = False
    for line in non_empty_lines:
        if "Outline" in line:
            in_content = True
            continue
        if "Summary Learning Points" in line or "Learning Activity" in line:
            in_content = False
            continue
        if in_content:
            content_lines.append(line.lstrip())

    text = " ".join(content_lines)
    # Replace (rather than extend) the shared list so questions from earlier
    # uploads do not leak into this result. Skip the model when there is no
    # content to ask questions about.
    qa_list[:] = nlp(text) if text else []

    formatted_questions = "\n".join(
        f"{idx}. {qa['question']}" for idx, qa in enumerate(qa_list, start=1)
    )
    formatted_answers = "\n".join(
        f"{idx}. {qa['answer']}" for idx, qa in enumerate(qa_list, start=1)
    )
    return [formatted_questions, formatted_answers]
40
+
41
def reveal_answer():
    """Return the numbered answers for the stored questions, then clear them.

    Returns:
        A newline-separated string of 1-based numbered answers built from the
        module-level ``qa_list``; empty if nothing is stored.
    """
    global qa_list

    # Build the text before resetting: the original returned an undefined
    # name (`formatted_answers`) and raised NameError on every call.
    formatted_answers = "\n".join(
        f"{idx}. {qa['answer']}" for idx, qa in enumerate(qa_list, start=1)
    )
    qa_list = []
    return formatted_answers
46
+
47
# Two single-line text boxes receive the pair returned by process_file.
questions_box = gr.Textbox(lines=1, label="Questions")
answers_box = gr.Textbox(lines=1, label="Answers")

# One file upload in, questions and answers out; start serving immediately.
io = gr.Interface(process_file, "file", outputs=[questions_box, answers_box])
io.launch()
question_generation/README.md ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Question Generation using 🤗transformers
2
+
3
+ - [Question Generation using 🤗transformers](#question-generation-using-transformers)
4
+ - [Project Details](#project-details)
5
+ - [Initial experiments](#initial-experiments)
6
+ - [answer aware question generation](#answer-aware-question-generation)
7
+ - [answer extraction models](#answer-extraction-models)
8
+ - [Multitask QA-QG](#multitask-qa-qg)
9
+ - [End-to-End question generation (answer agnostic)](#end-to-end-question-generation-answer-agnostic)
10
+ - [Results](#results)
11
+ - [Requirements](#requirements)
12
+ - [Usage](#usage)
13
+ - [Question Generation](#question-generation)
14
+ - [Multitask QA-QG](#multitask-qa-qg-1)
15
+ - [End-to-end question generation (without answer supervision)](#end-to-end-question-generation-without-answer-supervision)
16
+ - [Fine-tuning](#fine-tuning)
17
+ - [Data processing](#data-processing)
18
+ - [training](#training)
19
+ - [Evaluation](#evaluation)
20
+ - [Applications 🚀](#applications-)
21
+ - [Relevant papers](#relevant-papers)
22
+
23
+
24
+ ## Project Details
25
+ Question generation is the task of automatically generating questions from a text paragraph. The most straight-forward way for this is answer aware question generation. In answer aware question generation the model is presented with the answer and the passage and asked to generate a question for that answer by considering the passage context. While there are many papers available for the QG task, it's still not as mainstream as QA. One of the reasons is that most of the earlier papers use complicated models/processing pipelines and have no pre-trained models available. A few recent papers, specifically UniLM and ProphetNet, have SOTA pre-trained weights available for QG but the usage seems quite complicated.
26
+
27
+ This project is aimed as an open source study on question generation with pre-trained transformers (specifically seq-2-seq models) using straight-forward end-to-end methods without much complicated pipelines. The goal is to provide simplified data processing and training scripts and easy to use pipelines for inference.
28
+
29
+
30
+ ## Initial experiments
31
+ Initial experiments are conducted using the SQuADv1 dataset and T5 model with different input processing formats as described below.
32
+
33
+ ### answer aware question generation
34
+
35
+ For answer aware models the input text can be processed in two ways.
36
+
37
+ **1. prepend format:**
38
+
39
+ Here the answer is simply added before the context and separated by a sep token. For example
40
+
41
+ `42 [SEP] 42 is the answer to life, the universe and everything.`
42
+
43
+ for T5 model the input is processed like this
44
+
45
+ `answer: 42 context: 42 is the answer to life, the universe and everything.`
46
+
47
+ **2. highlight format**
48
+
49
+ Here the answer span is highlighted within the text with special highlight tokens.
50
+
51
+ `<hl> 42 <hl> is the answer to life, the universe and everything.`
52
+
53
+ This idea is proposed in the "A Recurrent BERT-based Model for Question Generation" [paper](https://www.aclweb.org/anthology/D19-5821.pdf). See section 4.3
54
+
55
+ ### answer extraction models
56
+
57
+ As the answer aware models need answers for generating questions, we need something which can extract answer like spans from the text. This can be done using various methods like NER, noun-phrase extraction etc. But here a model is trained to extract answer like spans, to see how it'll work. With T5, answer extraction is done using the text-to-text format.
58
+
59
+ As the highlight format will need to know the position of extracted answer spans the input for answer extraction is processed as follows
60
+
61
+ 1. split the text into sentences.
62
+ 2. for each sentence that has answers, highlight the sentence with `<hl>` tokens.
63
+ 3. for the target text join the answers in that sentence with `<sep>` tokens.
64
+
65
+ For example for this text
66
+
67
+ `Python is a programming language. Created by Guido van Rossum and first released in 1991.`
68
+
69
+ following examples will be created
70
+
71
+ Input text:
72
+ `<hl> Python is a programming language. <hl> Created by Guido van Rossum and first released in 1991.`
73
+
74
+ target text:
75
+ `Python <sep>`
76
+
77
+ and
78
+
79
+ Input text:
80
+ `Python is a programming language. <hl> Created by Guido van Rossum and first released in 1991 <hl>.`
81
+
82
+ target text:
83
+ `Guido van Rossum <sep> 1991 <sep>`
84
+
85
+ At inference time the text is split into sentences and each sentence is highlighted.
86
+
87
+ ### Multitask QA-QG
88
+
89
+ For answer aware question generation we usually need 3 models, first which will extract answer like spans, second model will generate question on that answer and third will be a QA model which will take the question and produce an answer,
90
+ then we can compare the two answers to see if the generated question is correct or not.
91
+
92
+ Having 3 models for single task is lot of complexity, so goal is to create a multi-task model which can do all of these 3 tasks
93
+
94
+ 1. extract answer like spans
95
+ 2. generate question based on the answer
96
+ 3. QA
97
+
98
+ T5 model is fine-tuned in multi-task way using task prefixes as described in the paper.
99
+
100
+ <p align="center">
101
+ <img width="80%" src="https://i.ibb.co/TBS3nsr/t5-ss-2.png">
102
+ </p>
103
+
104
+ ### End-to-End question generation (answer agnostic)
105
+
106
+ In end-to-end question generation the model is asked to generate questions without providing the answers. [This](https://arxiv.org/pdf/2005.01107v1.pdf) paper discusses these ideas in more detail. Here the T5 model is trained to generate multiple questions simultaneously by just providing the context. The questions are separated by the `<sep>` token. Here's how the examples are processed
107
+
108
+ input text: `Python is a programming language. Created by Guido van Rossum and first released in 1991.`
109
+
110
+ target text: `Who created Python ? <sep> When was python released ? <sep>`
111
+
112
+ **All the training details can be found in [this](https://app.wandb.ai/psuraj/question-generation) wandb project**
113
+
114
+ ## Results
115
+
116
+ Results on the SQuAD1.0 dev set using above approaches. For decoding, beam search with num_beams 4 is used with max decoding length set to 32.
117
+
118
+ For multitask qa-qg models the EM and F1 scores are provided as QA-EM and QA-F1.
119
+
120
+ The [nlg-eval](https://github.com/Maluuba/nlg-eval) package is used for calculating the metrics.
121
+
122
+
123
+ | Name | BLEU-4 | METEOR | ROUGE-L | QA-EM | QA-F1 | QG-FORMAT |
124
+ |----------------------------------------------------------------------------|---------|---------|---------|--------|--------|-----------|
125
+ | [t5-base-qg-hl](https://huggingface.co/valhalla/t5-base-qg-hl) | 21.3226 | 27.0854 | 43.5962 | - | - | highlight |
126
+ | [t5-base-qa-qg-hl](https://huggingface.co/valhalla/t5-base-qa-qg-hl) | 21.0141 | 26.9113 | 43.2484 | 82.46 | 90.272 | highlight |
127
+ | [t5-small-qa-qg-hl](https://huggingface.co/valhalla/t5-small-qa-qg-hl) | 18.9872 | 25.2217 | 40.7893 | 76.121 | 84.904 | highlight |
128
+ | [t5-small-qg-hl](https://huggingface.co/valhalla/t5-small-qg-hl) | 18.5921 | 24.9915 | 40.1886 | - | - | highlight |
129
+ | [t5-small-qg-prepend](https://huggingface.co/valhalla/t5-small-qg-prepend) | 18.2791 | 24.6722 | 39.958 | - | - | prepend |
130
+
131
+
132
+ ## Requirements
133
+ ```
134
+ transformers==3.0.0
135
+ nltk
136
+ nlp==0.2.0 # only if you want to fine-tune.
137
+ ```
138
+
139
+ after installing `nltk` do
140
+ ```bash
141
+ python -m nltk.downloader punkt
142
+ ```
143
+
144
+ ## Usage
145
+ Use the pipeline which mimics 🤗transformers pipeline for easy inference.
146
+
147
+ The pipeline is divided into 3 tasks
148
+ 1. `question-generation`: for single task question generation models.
149
+ 2. `multitask-qa-qg`: for multi-task qa,qg models.
150
+ 3. `e2e-qg`: for end-to-end question generation.
151
+
152
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/question_generation/blob/master/question_generation.ipynb)
153
+
154
+ #### Question Generation
155
+
156
+ ```python3
157
+ from pipelines import pipeline
158
+
159
+ nlp = pipeline("question-generation")
160
+ nlp("42 is the answer to life, the universe and everything.")
161
+ => [{'answer': '42', 'question': 'What is the answer to life, the universe and everything?'}]
162
+ ```
163
+
164
+ **prepend format**
165
+ ```python3
166
+ nlp = pipeline("question-generation", model="valhalla/t5-small-qg-prepend", qg_format="prepend")
167
+ nlp("42 is the answer to life, the universe and everything.")
168
+ => [{'answer': '42 ', 'question': 'What is the answer to life, the universe, and everything?'}]
169
+ ```
170
+
171
+ #### Multitask QA-QG
172
+ ```python3
173
+ nlp = pipeline("multitask-qa-qg")
174
+
175
+ # to generate questions simply pass the text
176
+ nlp("42 is the answer to life, the universe and everything.")
177
+ => [{'answer': '42', 'question': 'What is the answer to life, the universe and everything?'}]
178
+
179
+ # for qa pass a dict with "question" and "context"
180
+ nlp({
181
+ "question": "What is 42 ?",
182
+ "context": "42 is the answer to life, the universe and everything."
183
+ })
184
+ => 'the answer to life, the universe and everything'
185
+ ```
186
+
187
+ #### End-to-end question generation (without answer supervision)
188
+ ```python3
189
+ nlp = pipeline("e2e-qg")
190
+ nlp("Python is a programming language. Created by Guido van Rossum and first released in 1991.")
191
+ => [
192
+ 'What is a programming language?',
193
+ 'Who created Python?',
194
+ 'When was Python first released?'
195
+ ]
196
+ ```
197
+
198
+ By default both pipelines will use the t5-small* models, to use the other models pass the path through `model` parameter.
199
+
200
+ By default the `question-generation` pipeline will download the [valhalla/t5-small-qg-hl](https://huggingface.co/valhalla/t5-small-qg-hl) model with `highlight` qg format. If you want to use prepend format then provide the path to the prepend model and set `qg_format` to `"prepend"`. For extracting answer like spans it uses [valhalla/t5-small-qa-qg-hl](https://huggingface.co/valhalla/t5-small-qa-qg-hl) model, you can provide a different model through `ans_model` parameter.
201
+
202
+ The `multitask-qa-qg` model is for multitask models which can extract answer like spans, do qg and qa, so it won't need a separate `ans_model`. By default [valhalla/t5-small-qa-qg-hl](https://huggingface.co/valhalla/t5-small-qa-qg-hl) model is used with `highlight` format. If you want to use prepend format then provide the path to the prepend model and set `qg_format` to `"prepend"`
203
+
204
+ The `e2e-qg` pipeline is for end-to-end question generation. These models can generate multiple questions simultaneously without answer supervision. By default it uses [valhalla/t5-small-e2e-qg](https://huggingface.co/valhalla/t5-small-e2e-qg)
205
+
206
+ ## Fine-tuning
207
+
208
+ ### Data processing
209
+
210
+ To support different data formats the trainer expects pre-processed cached dataset, so you can process the data the way you want.
211
+ The cached dataset should be saved using `torch.save` and it should return a `dict` with `source_ids`, `target_ids`, `attention_mask` keys from `__getitem__`.
212
+
213
+ - `source_ids`: encoded source text
214
+ - `target_ids`: encoded target text
215
+ - `attention_mask`: attention mask for the `source_ids`
216
+
217
+ The `T2TDataCollator` takes care of preparing right `input_ids` and `labels`. It also trims the batches dynamically to remove excessive padding tokens, to speed up the training.
218
+
219
+ The `data/squad_multitask` contains the modified SQuAD dataset for answer aware question generation (using both prepend and highlight formats), question answering (text-to-text), answer extraction and end-to-end question generation. This dataset can be loaded using the awesome 🤗`nlp` library, this makes processing very easy.
220
+
221
+ To process and cache the dataset use `prepare_data.py` script. It will load the correct tokenizer depending on the `model_type` argument. It adds two new tokens `<sep>` and `<hl>` to the tokenizer and saves it at `{model_type}_qg_tokenizer` path. You should pass this tokenizer to the fine-tuning script.
222
+
223
+ The datasets will be saved in `data/` directory. You should provide filenames using `train_file_name` and `valid_file_name` arguments.
224
+
225
+ **process data for single task question generation with highlight_qg_format**
226
+ ```bash
227
+ python prepare_data.py \
228
+ --task qg \
229
+ --model_type t5 \
230
+ --dataset_path data/squad_multitask/ \
231
+ --qg_format highlight_qg_format \
232
+ --max_source_length 512 \
233
+ --max_target_length 32 \
234
+ --train_file_name train_data_qg_hl_t5.pt \
235
+ --valid_file_name valid_data_qg_hl_t5.pt
236
+ ```
237
+
238
+ **process data for multi-task qa-qg with highlight_qg_format**
239
+
240
+ `valid_for_qg_only` argument is used to decide if the validation set should only contain data for qg task. For my multi-task experiments I used validation data with only qg task so that the eval loss curve can be easily compared with other single task models
241
+
242
+ ```bash
243
+ python prepare_data.py \
244
+ --task multi \
245
+ --valid_for_qg_only \
246
+ --model_type t5 \
247
+ --dataset_path data/squad_multitask/ \
248
+ --qg_format highlight_qg_format \
249
+ --max_source_length 512 \
250
+ --max_target_length 32 \
251
+ --train_file_name train_data_qa_qg_hl_t5.pt \
252
+ --valid_file_name valid_data_qg_hl_t5.pt
253
+ ```
254
+
255
+ **process dataset for end-to-end question generation**
256
+ ```bash
257
+ python prepare_data.py \
258
+ --task e2e_qg \
259
+ --valid_for_qg_only \
260
+ --model_type t5 \
261
+ --dataset_path data/squad_multitask/ \
262
+ --qg_format highlight_qg_format \
263
+ --max_source_length 512 \
264
+ --max_target_length 32 \
265
+ --train_file_name train_data_e2e_qg_t5.pt \
266
+ --valid_file_name valid_data_e2e_qg_t5.pt
267
+ ```
268
+
269
+ ### training
270
+ Use the `run_qg.py` script to start training. It uses transformers `Trainer` class for training the models.
271
+
272
+
273
+ ```bash
274
+ python run_qg.py \
275
+ --model_name_or_path t5-small \
276
+ --model_type t5 \
277
+ --tokenizer_name_or_path t5_qg_tokenizer \
278
+ --output_dir t5-small-qg-hl \
279
+ --train_file_path data/train_data_qg_hl_t5.pt \
280
+ --valid_file_path data/valid_data_qg_hl_t5.pt \
281
+ --per_device_train_batch_size 32 \
282
+ --per_device_eval_batch_size 32 \
283
+ --gradient_accumulation_steps 8 \
284
+ --learning_rate 1e-4 \
285
+ --num_train_epochs 10 \
286
+ --seed 42 \
287
+ --do_train \
288
+ --do_eval \
289
+ --evaluate_during_training \
290
+ --logging_steps 100
291
+ ```
292
+
293
+ or if you want to train it from script or notebook then
294
+
295
+ ```python3
296
+ from run_qg import run_qg
297
+
298
+ args_dict = {
299
+ "model_name_or_path": "t5-small",
300
+ "model_type": "t5",
301
+ "tokenizer_name_or_path": "t5_qg_tokenizer",
302
+ "output_dir": "t5-small-qg-hl",
303
+ "train_file_path": "data/train_data_qg_hl_t5.pt",
304
+ "valid_file_path": "data/valid_data_qg_hl_t5.pt",
305
+ "per_device_train_batch_size": 32,
306
+ "per_device_eval_batch_size": 32,
307
+ "gradient_accumulation_steps": 8,
308
+ "learning_rate": 1e-4,
309
+ "num_train_epochs": 10,
310
+ "seed": 42,
311
+ "do_train": True,
312
+ "do_eval": True,
313
+ "evaluate_during_training": True,
314
+ "logging_steps": 100
315
+ }
316
+
317
+ # start training
318
+ run_qg(args_dict)
319
+ ```
320
+
321
+ ### Evaluation
322
+
323
+ Use the `eval.py` script for evaluating the model.
324
+
325
+ ```bash
326
+ python eval.py \
327
+ --model_name_or_path t5-base-qg-hl \
328
+ --valid_file_path valid_data_qg_hl_t5.pt \
329
+ --model_type t5 \
330
+ --num_beams 4 \
331
+ --max_decoding_length 32 \
332
+ --output_path hypothesis_t5-base-qg-hl.txt
333
+ ```
334
+
335
+ This will save the output at {output_path} file.
336
+
337
+ To calculate the metrics install the [nlg-eval](https://github.com/Maluuba/nlg-eval) package and run
338
+
339
+ ```bash
340
+ nlg-eval --hypothesis=hypothesis_t5-base-qg-hl.txt --references=data/references.txt --no-skipthoughts --no-glove
341
+ ```
342
+
343
+ ## Applications 🚀
344
+
345
+ 1. A simple Trivia Quiz on topics of your choice - <br/>
346
+ [Medium article](https://medium.com/@nvarshney97/using-the-latest-nlp-techniques-for-fun-98f31ce7b556) and its [Colab Notebook](https://colab.research.google.com/gist/nrjvarshney/39ed6c80e2fe293b9e7eca5bc3a45b7d/quiz.ipynb)
347
+ 2. [Autocards, Accelerating learning through machine-generated flashcards](https://paulbricman.com/docs/tools/autocards/)
348
+
349
+ ## Relevant papers
350
+ - https://arxiv.org/abs/1906.05416
351
+ - https://www.aclweb.org/anthology/D19-5821/
352
+ - https://arxiv.org/abs/2005.01107v1
question_generation/data/references.txt ADDED
The diff for this file is too large to render. See raw diff
 
question_generation/data/squad_multitask/dataset_infos.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"prepend_qg_format": {"description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n", "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n", "homepage": "https://rajpurkar.github.io/SQuAD-explorer/", "license": "", "features": {"source_text": {"dtype": "string", "id": null, "_type": "Value"}, "target_text": {"dtype": "string", "id": null, "_type": "Value"}, "task": {"dtype": "string", "id": null, "_type": "Value"}}, "supervised_keys": null, "builder_name": "squad_multitask", "config_name": "prepend_qg_format", "version": {"version_str": "1.0.0", "description": "New split API (https://tensorflow.org/datasets/splits)", "nlp_version_to_prepare": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 225952922, "num_examples": 253276, "dataset_name": "squad_multitask"}, "validation": {"name": "validation", "num_bytes": 27650081, "num_examples": 30020, "dataset_name": "squad_multitask"}}, "download_checksums": {"https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json": {"num_bytes": 30288272, "checksum": "3527663986b8295af4f7fcdff1ba1ff3f72d07d61a20f487cb238a6ef92fd955"}, "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json": {"num_bytes": 4854279, "checksum": "95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"}}, "download_size": 35142551, "dataset_size": 253603003, "size_in_bytes": 288745554}, "highlight_qg_format": 
{"description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n", "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n", "homepage": "https://rajpurkar.github.io/SQuAD-explorer/", "license": "", "features": {"source_text": {"dtype": "string", "id": null, "_type": "Value"}, "target_text": {"dtype": "string", "id": null, "_type": "Value"}, "task": {"dtype": "string", "id": null, "_type": "Value"}}, "supervised_keys": null, "builder_name": "squad_multitask", "config_name": "highlight_qg_format", "version": {"version_str": "1.0.0", "description": "New split API (https://tensorflow.org/datasets/splits)", "nlp_version_to_prepare": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 226286197, "num_examples": 253276, "dataset_name": "squad_multitask"}, "validation": {"name": "validation", "num_bytes": 27698388, "num_examples": 30020, "dataset_name": "squad_multitask"}}, "download_checksums": {"https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json": {"num_bytes": 30288272, "checksum": "3527663986b8295af4f7fcdff1ba1ff3f72d07d61a20f487cb238a6ef92fd955"}, "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json": {"num_bytes": 4854279, "checksum": "95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"}}, "download_size": 35142551, "dataset_size": 253984585, "size_in_bytes": 289127136}, "prepend_highlight_qg_format": 
{"description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n", "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n", "homepage": "https://rajpurkar.github.io/SQuAD-explorer/", "license": "", "features": {"source_text": {"dtype": "string", "id": null, "_type": "Value"}, "target_text": {"dtype": "string", "id": null, "_type": "Value"}, "task": {"dtype": "string", "id": null, "_type": "Value"}}, "supervised_keys": null, "builder_name": "squad_multitask", "config_name": "prepend_highlight_qg_format", "version": {"version_str": "1.0.0", "description": "New split API (https://tensorflow.org/datasets/splits)", "nlp_version_to_prepare": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 227967699, "num_examples": 253276, "dataset_name": "squad_multitask"}, "validation": {"name": "validation", "num_bytes": 27893191, "num_examples": 30020, "dataset_name": "squad_multitask"}}, "download_checksums": {"https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json": {"num_bytes": 30288272, "checksum": "3527663986b8295af4f7fcdff1ba1ff3f72d07d61a20f487cb238a6ef92fd955"}, "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json": {"num_bytes": 4854279, "checksum": "95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"}}, "download_size": 35142551, "dataset_size": 255860890, "size_in_bytes": 291003441}}
question_generation/data/squad_multitask/dummy/plain_text/1.0.0/dummy_data.zip ADDED
Binary file (1.5 kB). View file
 
question_generation/data/squad_multitask/dummy/plain_text/1.0.0/dummy_data.zip.lock ADDED
File without changes
question_generation/data/squad_multitask/dummy/plain_text/1.0.0/dummy_data/dev ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "data": [
3
+ { "title": "dev test",
4
+ "paragraphs": [
5
+ { "context": "This is a test context.",
6
+ "qas": [
7
+ { "question": "Is this a test?",
8
+ "id": "2",
9
+ "answers": [
10
+ { "answer_start": 6,
11
+ "text": "This is a test text"
12
+ }
13
+ ]
14
+ }
15
+ ]
16
+ }
17
+ ]
18
+ }
19
+ ]
20
+ }
21
+
22
+
23
+
24
+
25
+
question_generation/data/squad_multitask/dummy/plain_text/1.0.0/dummy_data/train ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "data": [
3
+ { "title": "train test",
4
+ "paragraphs": [
5
+ { "context": "This is a test context.",
6
+ "qas": [
7
+ { "question": "Is this a test?",
8
+ "id": "1",
9
+ "answers": [
10
+ { "answer_start": 1,
11
+ "text": "This is a test text"
12
+ }
13
+ ]
14
+ }
15
+ ]
16
+ }
17
+ ]
18
+ }
19
+ ]
20
+ }
21
+
22
+
23
+
24
+
25
+
question_generation/data/squad_multitask/squad_multitask.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace NLP Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """SQUAD: The Stanford Question Answering Dataset."""
18
+
19
+ from __future__ import absolute_import, division, print_function
20
+
21
+ import json
22
+ import logging
23
+ import os
24
+
25
+ import nltk
26
+ nltk.download('punkt')
27
+
28
+ import nlp
29
+
30
+
31
+ _CITATION = """\
32
+ @article{2016arXiv160605250R,
33
+ author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},
34
+ Konstantin and {Liang}, Percy},
35
+ title = "{SQuAD: 100,000+ Questions for Machine Comprehension of Text}",
36
+ journal = {arXiv e-prints},
37
+ year = 2016,
38
+ eid = {arXiv:1606.05250},
39
+ pages = {arXiv:1606.05250},
40
+ archivePrefix = {arXiv},
41
+ eprint = {1606.05250},
42
+ }
43
+ """
44
+
45
+ _DESCRIPTION = """\
46
+ Stanford Question Answering Dataset (SQuAD) is a reading comprehension \
47
+ dataset, consisting of questions posed by crowdworkers on a set of Wikipedia \
48
+ articles, where the answer to every question is a segment of text, or span, \
49
+ from the corresponding reading passage, or the question might be unanswerable.
50
+ """
51
+
52
+ QG_FORMATS = [
53
+ "prepend",
54
+ "highlight",
55
+ "prepend_highlight",
56
+ ]
57
+
58
+
59
+ class SquadMultitaskConfig(nlp.BuilderConfig):
60
+ """BuilderConfig for SQUAD."""
61
+
62
+ def __init__(self, qg_format="highlight", **kwargs):
63
+ """BuilderConfig for SQUAD.
64
+
65
+ Args:
66
+ **kwargs: keyword arguments forwarded to super.
67
+ """
68
+ super(SquadMultitaskConfig, self).__init__(**kwargs)
69
+ self.qg_format = qg_format
70
+
71
+
72
+ class SquadMultitask(nlp.GeneratorBasedBuilder):
73
+ """SQUAD: The Stanford Question Answering Dataset. Version 1.1."""
74
+
75
+ _URL = "https://rajpurkar.github.io/SQuAD-explorer/dataset/"
76
+ _DEV_FILE = "dev-v1.1.json"
77
+ _TRAINING_FILE = "train-v1.1.json"
78
+
79
+ BUILDER_CONFIGS = [
80
+ SquadMultitaskConfig(
81
+ name=f"{format_}_qg_format",
82
+ version=nlp.Version("1.0.0", "New split API (https://tensorflow.org/datasets/splits)"),
83
+ description="Plain text",
84
+ qg_format=format_
85
+ )
86
+ for format_ in QG_FORMATS
87
+ ]
88
+
89
+ def _info(self):
90
+ return nlp.DatasetInfo(
91
+ description=_DESCRIPTION,
92
+ features=nlp.Features(
93
+ {
94
+ "source_text": nlp.Value("string"),
95
+ "target_text": nlp.Value("string"),
96
+ "task": nlp.Value("string"),
97
+ }
98
+ ),
99
+ # No default supervised_keys (as we have to pass both question
100
+ # and context as input).
101
+ supervised_keys=None,
102
+ homepage="https://rajpurkar.github.io/SQuAD-explorer/",
103
+ citation=_CITATION,
104
+ )
105
+
106
+ def _split_generators(self, dl_manager):
107
+ urls_to_download = {
108
+ "train": os.path.join(self._URL, self._TRAINING_FILE),
109
+ "dev": os.path.join(self._URL, self._DEV_FILE),
110
+ }
111
+ downloaded_files = dl_manager.download_and_extract(urls_to_download)
112
+
113
+ return [
114
+ nlp.SplitGenerator(name=nlp.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
115
+ nlp.SplitGenerator(name=nlp.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}),
116
+ ]
117
+
118
+ def _get_correct_alignement(self, context, answer):
119
+ """ Some original examples in SQuAD have indices wrong by 1 or 2 character. We test and fix this here. """
120
+ gold_text = answer['text']
121
+ start_idx = answer['answer_start']
122
+ end_idx = start_idx + len(gold_text)
123
+ if context[start_idx:end_idx] == gold_text:
124
+ return start_idx, end_idx # When the gold label position is good
125
+ elif context[start_idx-1:end_idx-1] == gold_text:
126
+ return start_idx-1, end_idx-1 # When the gold label is off by one character
127
+ elif context[start_idx-2:end_idx-2] == gold_text:
128
+ return start_idx-2, end_idx-2 # When the gold label is off by two character
129
+ else:
130
+ raise ValueError()
131
+
132
+ def process_qa_text(self, context, question, answer):
133
+ ans_gen_input = f"question: {question} context: {context}"
134
+ ans_gen_target = f"{answer}"
135
+ return {"source_text": ans_gen_input, "target_text": ans_gen_target, "task": "qa"}
136
+
137
+ def process_qg_text(self, context, question, answer):
138
+ answer_text = answer['text'].strip()
139
+
140
+ if self.config.qg_format == "prepend":
141
+ que_gen_input = f"answer: {answer_text} context: {context}"
142
+ elif self.config.qg_format == "highlight":
143
+ start_pos, end_pos = self._get_correct_alignement(context, answer)
144
+ que_gen_input = f"generate question: {context[:start_pos]} {{hl_token}} {answer_text} {{hl_token}} {context[end_pos:]}"
145
+ else:
146
+ start_pos, end_pos = self._get_correct_alignement(context, answer)
147
+ que_gen_input = f"answer: {answer_text} context: {context[:start_pos]} {{hl_token}} {answer_text} {{hl_token}} {context[end_pos:]}"
148
+
149
+ que_gen_target = f"{question}"
150
+ return {"source_text": que_gen_input, "target_text": que_gen_target, "task": "qg"}
151
+
152
+ def process_e2e_qg(self, paragraph):
153
+ source_text = f"generate questions: {paragraph['context'].strip()}"
154
+ questions = [qas['question'].strip() for qas in paragraph['qas']]
155
+ target_text = " {sep_token} ".join(questions)
156
+ target_text = f"{target_text} {{sep_token}}"
157
+ return {"source_text": source_text, "target_text": target_text, "task": "e2e_qg"}
158
+
159
+ def process_ans_ext(self, paragraph):
160
+ context = paragraph['context'].strip()
161
+
162
+ # split into sentences
163
+ sents = nltk.sent_tokenize(context)
164
+
165
+ # get positions of the sentences
166
+ positions = []
167
+ for i, sent in enumerate(sents):
168
+ if i == 0:
169
+ start, end = 0, len(sent)
170
+ else:
171
+ start, end = (prev_end + 1), (prev_end + len(sent) + 1)
172
+ prev_end = end
173
+ positions.append({'start': start, 'end': end})
174
+
175
+ # get answers
176
+ answers = [qa['answers'][0] for qa in paragraph['qas']]
177
+
178
+ # get list of answers for each sentence
179
+ sent_answers = []
180
+ for pos, sent in zip(positions, sents):
181
+ target_answers = []
182
+ for ans in answers:
183
+ if ans['answer_start'] in range(pos['start'], pos['end']):
184
+ target_answers.append(ans['text'].strip())
185
+ sent_answers.append(target_answers)
186
+
187
+ # build inputs and targets
188
+ examples = []
189
+ for i, ans in enumerate(sent_answers):
190
+ context = "extract answers:"
191
+ if len(ans) == 0: continue
192
+ ans = list(set(ans))
193
+ for j, sent in enumerate(sents):
194
+ if i == j:
195
+ sent = "{hl_token} %s {hl_token}" % sent
196
+ context = "%s %s" % (context, sent)
197
+ context = context.strip()
198
+ input_text = context
199
+ target_text = " {sep_token} ".join(ans) + " {sep_token}"
200
+
201
+ examples.append({'source_text': input_text, "target_text": target_text, "task": "ans_ext"})
202
+
203
+ return examples
204
+
205
+ def _generate_examples(self, filepath):
206
+ """This function returns the examples in the raw (text) form."""
207
+ logging.info("generating examples from = %s", filepath)
208
+ count = 0
209
+ tasks = ['qa', 'qg', 'ans_ext', 'e2e_qg']
210
+ with open(filepath) as f:
211
+ squad = json.load(f)
212
+ for article in squad["data"]:
213
+ title = article.get("title", "").strip()
214
+ for paragraph in article["paragraphs"]:
215
+ context = paragraph["context"].strip()
216
+
217
+ if 'ans_ext' in tasks:
218
+ ans_ext_examples = self.process_ans_ext(paragraph)
219
+ for example in ans_ext_examples:
220
+ yield count, example
221
+ count += 1
222
+
223
+ if 'e2e_qg' in tasks:
224
+ yield count, self.process_e2e_qg(paragraph)
225
+ count += 1
226
+
227
+ for qa in paragraph["qas"]:
228
+ question = qa["question"].strip()
229
+ id_ = qa["id"]
230
+
231
+ answers = [answer["text"].strip() for answer in qa["answers"]]
232
+ for task in tasks:
233
+ if task == 'qa':
234
+ yield count, self.process_qa_text(context, question, answers[0])
235
+ count += 1
236
+
237
+ if task == 'qg':
238
+ yield count, self.process_qg_text(context, question, qa["answers"][0])
239
+ count += 1
question_generation/data/squad_multitask/squad_multitask.py.lock ADDED
File without changes
question_generation/data_collator.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+
5
+
6
def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Drop every column that holds nothing but ``pad_token_id``.

    Returns the trimmed ``input_ids`` alone, or a ``(input_ids, attention_mask)``
    tuple trimmed with the same column selection when a mask is given.
    """
    # True for each column in which at least one row has a non-pad token.
    non_pad_columns = input_ids.ne(pad_token_id).any(dim=0)
    trimmed_ids = input_ids[:, non_pad_columns]
    if attention_mask is not None:
        return (trimmed_ids, attention_mask[:, non_pad_columns])
    return trimmed_ids
15
+
16
+
17
# Builds the label tensor from target_ids and returns a dict whose keys are the
# ones the model's forward method expects. This is necessary because the trainer
# passes the dict straight to the model as keyword arguments, so the keys have to
# match the parameter names of the forward method.
class T2TDataCollator():
    def __init__(self, tokenizer, model_type="t5", mode='training', using_tpu=False):
        self.tokenizer = tokenizer
        self.model_type = model_type
        self.mode = mode
        self.using_tpu = using_tpu

    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        """
        Collate a list of dataset samples into one batch.
        Returns:
            A dictionary with "input_ids", "attention_mask", "labels" and
            "decoder_input_ids" tensors
        """
        def stack(key):
            return torch.stack([sample[key] for sample in batch])

        input_ids = stack('source_ids')
        target_ids = stack('target_ids')
        attention_mask = stack('attention_mask')

        pad_token_id = self.tokenizer.pad_token_id

        # don't trim on tpu, for some reason trimming leads to slower training on TPU
        if not self.using_tpu:
            input_ids, attention_mask = trim_batch(input_ids, pad_token_id, attention_mask=attention_mask)
            target_ids = trim_batch(target_ids, pad_token_id)

        training = self.mode == 'training'
        if self.model_type == "t5":
            lm_labels = target_ids.clone()
            # The decoder inputs are derived before pad positions are masked.
            decoder_input_ids = self._shift_right_t5(lm_labels)
            if training:
                pad_positions = lm_labels[:, :] == pad_token_id
                lm_labels[pad_positions] = -100
        else:
            decoder_input_ids = target_ids[:, :-1].contiguous()
            lm_labels = target_ids[:, 1:].clone()
            if training:
                pad_positions = target_ids[:, 1:] == pad_token_id
                lm_labels[pad_positions] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": lm_labels,
            "decoder_input_ids": decoder_input_ids
        }

    def _shift_right_t5(self, input_ids):
        """Return ``input_ids`` shifted one position right, starting with the pad token id."""
        start_token_id = self.tokenizer.pad_token_id
        pad_id = self.tokenizer.pad_token_id

        assert (
            start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[..., 1:] = input_ids[..., :-1].clone()
        shifted[..., 0] = start_token_id

        assert pad_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted.masked_fill_(shifted == -100, pad_id)

        assert torch.all(shifted >= 0).item(), "Verify that `labels` has only positive values and -100"

        return shifted
question_generation/eval.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from tqdm.auto import tqdm
7
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser
8
+
9
+ from data_collator import T2TDataCollator
10
+
11
# torch.cuda.is_available has to be *called*: the bare function object is always
# truthy, which selected 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

logger = logging.getLogger(__name__)
14
+
15
@dataclass
class EvalArguments:
    """Options of the evaluation script; the first three have no default."""

    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"),
    )
    valid_file_path: str = field(metadata=dict(help="Path for cached valid dataset"))
    model_type: str = field(metadata=dict(help="One of 't5', 'bart'"))
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
    )
    num_beams: Optional[int] = field(default=4, metadata=dict(help="num_beams to use for decoding"))
    max_decoding_length: Optional[int] = field(default=32, metadata=dict(help="maximum length for decoding"))
    output_path: Optional[str] = field(
        default="hypothesis.txt",
        metadata=dict(help="path to save the generated questions."),
    )
39
+
40
def get_predictions(model, tokenizer, data_loader, num_beams=4, max_length=32, length_penalty=1):
    """Generate text for every batch of *data_loader* and return the decoded strings.

    The model is moved to the module-level ``device`` and switched to eval mode;
    generation runs without gradient tracking.  Each batch must provide
    ``input_ids`` and ``attention_mask``.
    """
    model.to(device)
    model.eval()

    decoded = []
    with torch.no_grad():
        for batch in tqdm(data_loader):
            generated = model.generate(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                num_beams=num_beams,
                max_length=max_length,
                length_penalty=length_penalty,
            )
            decoded.extend([tokenizer.decode(ids, skip_special_tokens=True) for ids in generated])

    return decoded
59
+
60
def main():
    """Generate predictions for a cached validation set and save them, one per line.

    Parses ``EvalArguments`` from the command line, loads tokenizer and model,
    runs ``get_predictions`` over the dataset stored at ``valid_file_path`` and
    writes the results to ``output_path``.
    """
    parser = HfArgumentParser((EvalArguments,))
    args = parser.parse_args_into_dataclasses()[0]

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)

    # NOTE(review): torch.load unpickles the file - only use it on trusted dataset caches.
    valid_dataset = torch.load(args.valid_file_path)
    collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=args.model_type,
        mode="inference"
    )
    loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, collate_fn=collator)

    predictions = get_predictions(
        model=model,
        tokenizer=tokenizer,
        data_loader=loader,
        num_beams=args.num_beams,
        max_length=args.max_decoding_length
    )

    # Explicit encoding: generated text may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to handle them.
    with open(args.output_path, 'w', encoding="utf-8") as f:
        f.write("\n".join(predictions))

    # Use the module logger (not the root logger) like the rest of the project.
    logger.info("Output saved at %s", args.output_path)


if __name__ == "__main__":
    main()
question_generation/notebooks/question_generation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
question_generation/pipelines.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import logging
3
+ from typing import Optional, Dict, Union
4
+
5
+ from nltk import sent_tokenize
6
+
7
+ import torch
8
+ from transformers import(
9
+ AutoModelForSeq2SeqLM,
10
+ AutoTokenizer,
11
+ PreTrainedModel,
12
+ PreTrainedTokenizer,
13
+ )
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
class QGPipeline:
    """Poor man's QG pipeline

    Two-stage question generation: ``ans_model`` first extracts candidate
    answers sentence by sentence, then ``model`` generates one question per
    answer, with inputs formatted according to ``qg_format``.
    """
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        ans_model: PreTrainedModel,
        ans_tokenizer: PreTrainedTokenizer,
        qg_format: str,
        use_cuda: bool
    ):
        self.model = model
        self.tokenizer = tokenizer

        self.ans_model = ans_model
        self.ans_tokenizer = ans_tokenizer

        self.qg_format = qg_format

        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self.model.to(self.device)

        # Only move the answer model when it is a distinct object.
        if self.ans_model is not self.model:
            self.ans_model.to(self.device)

        assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]

        if "T5ForConditionalGeneration" in self.model.__class__.__name__:
            self.model_type = "t5"
        else:
            self.model_type = "bart"

    def __call__(self, inputs: str):
        """Return a list of ``{'answer': ..., 'question': ...}`` dicts for *inputs*.

        An empty list is returned when the text is empty or no usable answer
        could be extracted.
        """
        # Collapse all runs of whitespace into single spaces.
        inputs = " ".join(inputs.split())
        if not inputs:
            return []

        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))

        if len(flat_answers) == 0:
            return []

        if self.qg_format == "prepend":
            qg_examples = self._prepare_inputs_for_qg_from_answers_prepend(inputs, answers)
        else:
            qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)

        # Every extracted answer may have been rejected while building the inputs.
        if not qg_examples:
            return []

        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        output = [{'answer': example['answer'], 'question': que} for example, que in zip(qg_examples, questions)]
        return output

    def _generate_questions(self, inputs):
        """Generate one question per prepared source text."""
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
            num_beams=4,
        )

        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        return questions

    def _extract_answers(self, context):
        """Return ``(sentences, answers)``; ``answers[i]`` lists the answers for sentence ``i``."""
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        if not inputs:
            return sents, []

        # Encode with the answer model's own tokenizer: the output below is
        # decoded with it, and it may differ from the question-generation one.
        inputs = self._tokenize(inputs, padding=True, truncation=True, tokenizer=self.ans_tokenizer)

        outs = self.ans_model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
        )

        # Special tokens are kept so the '<sep>' separators survive decoding.
        dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]

        # NOTE(review): depending on the tokenizer version the pad marker may also
        # survive decoding and end up glued to the first answer; strip it.
        pad_token = getattr(self.ans_tokenizer, "pad_token", None)
        if pad_token:
            dec = [item.replace(pad_token, "") for item in dec]

        # The segment after the last '<sep>' is not an answer, drop it.
        answers = [item.split('<sep>')[:-1] for item in dec]

        return sents, answers

    def _tokenize(self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512,
        tokenizer=None
    ):
        """Batch-encode *inputs* into PyTorch tensors.

        Uses ``self.tokenizer`` unless another *tokenizer* is passed.
        """
        tokenizer = self.tokenizer if tokenizer is None else tokenizer
        inputs = tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            pad_to_max_length=padding,
            return_tensors="pt"
        )
        return inputs

    def _prepare_inputs_for_ans_extraction(self, text):
        """Return the sentences of *text* and one source text per sentence.

        Each source text is the full passage with exactly one sentence wrapped
        in ``<hl>`` markers.
        """
        sents = sent_tokenize(text)

        inputs = []
        for i in range(len(sents)):
            marked = ["<hl> %s <hl>" % sent if i == j else sent for j, sent in enumerate(sents)]
            source_text = " ".join(["extract answers:"] + marked).strip()

            if self.model_type == "t5":
                source_text = source_text + " </s>"
            inputs.append(source_text)

        return sents, inputs

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        """Build highlight-format inputs, one per answer found in its sentence.

        Answers that are empty or do not occur verbatim in their sentence are
        skipped (with a warning) instead of aborting the whole request.
        """
        inputs = []
        for i, answer in enumerate(answers):
            if len(answer) == 0: continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]

                answer_text = answer_text.strip()
                if not answer_text:
                    continue

                # Generated answers are not guaranteed to be exact substrings.
                ans_start_idx = sent.find(answer_text)
                if ans_start_idx == -1:
                    logger.warning("skipping answer %r: not found in its sentence", answer_text)
                    continue

                sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text): ]}"
                sents_copy[i] = sent

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "t5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs

    def _prepare_inputs_for_qg_from_answers_prepend(self, context, answers):
        """Build prepend-format inputs: the answer followed by the whole context."""
        flat_answers = list(itertools.chain(*answers))
        examples = []
        for answer in flat_answers:
            source_text = f"answer: {answer} context: {context}"
            if self.model_type == "t5":
                source_text = source_text + " </s>"

            examples.append({"answer": answer, "source_text": source_text})
        return examples
166
+
167
+
168
class MultiTaskQAQGPipeline(QGPipeline):
    """Pipeline that generates questions for a string and answers a question for a dict."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, inputs: Union[Dict, str]):
        """Dispatch on the input type.

        A string is treated as a passage for question generation; anything else
        must provide ``"question"`` and ``"context"`` keys and is answered.
        """
        # isinstance (not an exact type check) so str subclasses are also
        # routed to question generation.
        if isinstance(inputs, str):
            # do qg
            return super().__call__(inputs)
        else:
            # do qa
            return self._extract_answer(inputs["question"], inputs["context"])

    def _prepare_inputs_for_qa(self, question, context):
        """Format the question/context pair as model input text."""
        source_text = f"question: {question} context: {context}"
        if self.model_type == "t5":
            source_text = source_text + " </s>"
        return source_text

    def _extract_answer(self, question, context):
        """Generate and decode the answer to *question* given *context*."""
        source_text = self._prepare_inputs_for_qa(question, context)
        inputs = self._tokenize([source_text], padding=False)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=16,
        )

        answer = self.tokenizer.decode(outs[0], skip_special_tokens=True)
        return answer
198
+
199
+
200
class E2EQGPipeline:
    """End-to-end question generation: one model call turns a passage into questions."""

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        use_cuda: bool
    ) :

        self.model = model
        self.tokenizer = tokenizer

        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self.model.to(self.device)

        assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]

        if "T5ForConditionalGeneration" in self.model.__class__.__name__:
            self.model_type = "t5"
        else:
            self.model_type = "bart"

        # Used only when the caller passes no generation arguments at all.
        self.default_generate_kwargs = {
            "max_length": 256,
            "num_beams": 4,
            "length_penalty": 1.5,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }

    def __call__(self, context: str, **generate_kwargs):
        """Return the list of questions generated for *context*.

        ``generate_kwargs`` are forwarded to ``model.generate``.  They replace
        the defaults wholesale: passing any argument means none of
        ``default_generate_kwargs`` is applied.
        """
        inputs = self._prepare_inputs_for_e2e_qg(context)

        # TODO: when overriding default_generate_kwargs all other arguments need to be passed;
        # find a better way to do this
        if not generate_kwargs:
            generate_kwargs = self.default_generate_kwargs

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            **generate_kwargs
        )

        prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
        # Questions are terminated by '<sep>'; whatever follows the last one is dropped.
        questions = prediction.split("<sep>")
        questions = [question.strip() for question in questions[:-1]]
        return questions

    def _prepare_inputs_for_e2e_qg(self, context):
        """Prefix the task instruction and tokenize the passage without padding."""
        source_text = f"generate questions: {context}"
        if self.model_type == "t5":
            source_text = source_text + " </s>"

        inputs = self._tokenize([source_text], padding=False)
        return inputs

    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        """Batch-encode *inputs* with ``self.tokenizer`` into PyTorch tensors."""
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            pad_to_max_length=padding,
            return_tensors="pt"
        )
        return inputs
284
+
285
+
286
# Task registry used by `pipeline`: for each task name, the implementing class
# ("impl") and the checkpoint identifiers used when the caller supplies none
# ("default").
SUPPORTED_TASKS = {
    "question-generation": dict(
        impl=QGPipeline,
        default=dict(
            model="valhalla/t5-small-qg-hl",
            ans_model="valhalla/t5-small-qa-qg-hl",
        ),
    ),
    "multitask-qa-qg": dict(
        impl=MultiTaskQAQGPipeline,
        default=dict(model="valhalla/t5-small-qa-qg-hl"),
    ),
    "e2e-qg": dict(
        impl=E2EQGPipeline,
        default=dict(model="valhalla/t5-small-e2e-qg"),
    ),
}
307
+
308
def _resolve_tokenizer(tokenizer, model):
    """Return a tokenizer instance, inferring its identifier from *model* when none is given.

    *tokenizer* may be ``None``, an identifier/path string, a
    ``(identifier, kwargs)`` tuple or an already instantiated tokenizer, which
    is returned unchanged.  *model* must still be a string for the inference
    to work.
    """
    if tokenizer is None:
        if isinstance(model, str):
            tokenizer = model
        else:
            # Impossible to guess what the right tokenizer is here
            raise Exception(
                "Impossible to guess which tokenizer to use. "
                "Please provided a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
            )

    if isinstance(tokenizer, tuple):
        # For tuple we have (tokenizer name, {kwargs})
        return AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
    if isinstance(tokenizer, str):
        return AutoTokenizer.from_pretrained(tokenizer)
    return tokenizer


def pipeline(
    task: str,
    model: Optional = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    qg_format: Optional[str] = "highlight",
    ans_model: Optional = None,
    ans_tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    use_cuda: Optional[bool] = True,
    **kwargs,
):
    """Build the pipeline object registered for *task* in ``SUPPORTED_TASKS``.

    Models and tokenizers may be given as instances or as identifier strings;
    missing ones fall back to the task defaults.  ``ans_model`` and
    ``ans_tokenizer`` are only used by the ``"question-generation"`` task.

    Raises:
        KeyError: if *task* is not a supported task.
        Exception: if a tokenizer is missing and cannot be inferred because the
            corresponding model was passed as an instance.
    """
    # Retrieve the task
    if task not in SUPPORTED_TASKS:
        raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

    targeted_task = SUPPORTED_TASKS[task]
    task_class = targeted_task["impl"]

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        model = targeted_task["default"]["model"]

    # The tokenizer has to be resolved while `model` may still be a string.
    tokenizer = _resolve_tokenizer(tokenizer, model)

    # Instantiate model if needed
    if isinstance(model, str):
        model = AutoModelForSeq2SeqLM.from_pretrained(model)

    if task == "question-generation":
        if ans_model is None:
            # Load the default answer model; its tokenizer always comes from the
            # same checkpoint, whatever `ans_tokenizer` was passed.
            ans_model = targeted_task["default"]["ans_model"]
            ans_tokenizer = _resolve_tokenizer(None, ans_model)
        else:
            ans_tokenizer = _resolve_tokenizer(ans_tokenizer, ans_model)

        if isinstance(ans_model, str):
            ans_model = AutoModelForSeq2SeqLM.from_pretrained(ans_model)

    if task == "e2e-qg":
        return task_class(model=model, tokenizer=tokenizer, use_cuda=use_cuda)
    elif task == "question-generation":
        return task_class(model=model, tokenizer=tokenizer, ans_model=ans_model, ans_tokenizer=ans_tokenizer, qg_format=qg_format, use_cuda=use_cuda)
    else:
        # The multitask model extracts answers itself, so it doubles as answer model.
        return task_class(model=model, tokenizer=tokenizer, ans_model=model, ans_tokenizer=tokenizer, qg_format=qg_format, use_cuda=use_cuda)
question_generation/prepare_data.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List, Optional
5
+
6
+ import torch
7
+ import nlp
8
+ from transformers import T5Tokenizer, BartTokenizer, HfArgumentParser
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
@dataclass
class DataTrainingArguments:
    """
    Options describing which data is prepared for training and evaluation, and how.
    """
    task: str = field(
        metadata=dict(help="Which task 'qa', 'qg', 'e2e_qg', 'ans_ext', 'multi'. 'multi' means 'qa', 'qg', 'ans_ext' tasks"),
    )
    model_type: str = field(metadata=dict(help="One of 't5', 'bart'"))
    dataset_path: Optional[str] = field(
        default="data/squad_multitask",
        metadata=dict(help="Path for dataset directory"),
    )
    train_file_name: Optional[str] = field(default=None, metadata=dict(help="name for cached train dataset"))
    valid_file_name: Optional[str] = field(default=None, metadata=dict(help="name for cached valid dataset"))
    valid_for_qg_only: bool = field(
        default=False,
        metadata=dict(help="For multitask dataset valid split should contain only qg task or all tasks."),
    )
    qg_format: Optional[str] = field(
        default='highlight_qg_format',
        metadata=dict(help="How to format inputs for que generation, 'highlight_qg_format' or 'prepend_qg_format'"),
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata=dict(help="Max input length for the source text"),
    )
    max_target_length: Optional[int] = field(
        default=32,
        metadata=dict(help="Max input length for the target text"),
    )
51
+
52
class DataProcessor:
    """Turns text-to-text examples into padded token-id features."""

    def __init__(self, tokenizer, model_type="t5", max_source_length=512, max_target_length=32):
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.model_type = model_type
        self.hl_token = "<hl>"
        # t5 and bart share the same separator; anything else gets "[SEP]".
        self.sep_token = "<sep>" if model_type in ("t5", "bart") else "[SEP]"

    def process(self, dataset):
        """Run the mapping steps on *dataset* (eos only for t5) and return the result."""
        if self.model_type == "t5":
            dataset = dataset.map(self._add_eos_examples)

        with_tokens = dataset.map(self._add_special_tokens)
        return with_tokens.map(self._convert_to_features, batched=True)

    def _add_eos_examples(self, example):
        """Append the " </s>" end-of-sequence marker to source and target text."""
        for key in ('source_text', 'target_text'):
            example[key] = example[key] + " </s>"
        return example

    def _add_special_tokens(self, example):
        """Replace the {hl_token} / {sep_token} placeholders with the actual tokens."""
        source, target = example['source_text'], example['target_text']
        example['source_text'] = source.replace("{hl_token}", self.hl_token)
        example['target_text'] = target.replace("{sep_token}", self.sep_token)
        return example

    # tokenize the examples
    def _convert_to_features(self, example_batch):
        """Encode a batch of texts, padding/truncating to the configured lengths."""
        def encode(texts, limit):
            return self.tokenizer.batch_encode_plus(
                texts,
                max_length=limit,
                padding='max_length',
                pad_to_max_length=True,
                truncation=True,
            )

        source_encoding = encode(example_batch['source_text'], self.max_source_length)
        target_encoding = encode(example_batch['target_text'], self.max_target_length)

        return {
            'source_ids': source_encoding['input_ids'],
            'target_ids': target_encoding['input_ids'],
            'attention_mask': source_encoding['attention_mask'],
        }
110
+
111
+
112
def filter_qa(example):
    """True for question-answering examples."""
    task = example['task']
    return task == 'qa'


def filter_qg(example):
    """True for question-generation examples."""
    task = example['task']
    return task == 'qg'


def filter_e2e_qg(example):
    """True for end-to-end question-generation examples."""
    task = example['task']
    return task == 'e2e_qg'


def filter_ans_ext(example):
    """True for answer-extraction examples."""
    task = example['task']
    return task == 'ans_ext'


def filter_multi(example):
    """True for every example except the end-to-end question-generation ones."""
    task = example['task']
    return task != 'e2e_qg'


# Maps the task name to the predicate that selects its examples.
TASK_TO_FILTER_FN = dict(
    qa=filter_qa,
    qg=filter_qg,
    e2e_qg=filter_e2e_qg,
    ans_ext=filter_ans_ext,
    multi=filter_multi,
)
135
+
136
+
137
def main():
    """Tokenize the multitask dataset and cache train/validation sets plus tokenizer.

    Parses ``DataTrainingArguments``, loads the requested dataset configuration,
    keeps the examples of the selected task, converts them to features and
    saves both splits under ``data/`` and the extended tokenizer under
    ``<model_type>_qg_tokenizer``.
    """
    parser = HfArgumentParser((DataTrainingArguments,))

    data_args = parser.parse_args_into_dataclasses()[0]

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
    )

    if data_args.model_type == 't5':
        tokenizer = T5Tokenizer.from_pretrained("t5-base")
    else:
        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    # Separator and highlight markers used in the generated source/target texts.
    tokenizer.add_tokens(['<sep>', '<hl>'])

    train_dataset = nlp.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=nlp.Split.TRAIN)
    valid_dataset = nlp.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=nlp.Split.VALIDATION)

    processor = DataProcessor(
        tokenizer,
        model_type=data_args.model_type,
        max_source_length=data_args.max_source_length,
        max_target_length=data_args.max_target_length
    )

    train_dataset = train_dataset.filter(TASK_TO_FILTER_FN[data_args.task])
    if data_args.task == 'multi' and data_args.valid_for_qg_only:
        logger.info("processing valid data only for qg task")
        valid_dataset = valid_dataset.filter(filter_qg)
    else:
        valid_dataset = valid_dataset.filter(TASK_TO_FILTER_FN[data_args.task])

    train_dataset = processor.process(train_dataset)
    valid_dataset = processor.process(valid_dataset)

    columns = ["source_ids", "target_ids", "attention_mask"]
    train_dataset.set_format(type='torch', columns=columns)
    valid_dataset.set_format(type='torch', columns=columns)

    # Resolve each file name on its own: giving only one of the two names must
    # neither crash on a None path nor discard the name that was given.
    train_file_name = data_args.train_file_name
    if train_file_name is None:
        train_file_name = f"train_data_{data_args.task}_{data_args.qg_format}_{data_args.model_type}.pt"
    valid_file_name = data_args.valid_file_name
    if valid_file_name is None:
        valid_file_name = f"valid_data_{data_args.task}_{data_args.qg_format}_{data_args.model_type}.pt"

    # Make sure the output directory exists before saving into it.
    os.makedirs("data", exist_ok=True)
    train_path = os.path.join("data", train_file_name)
    valid_path = os.path.join("data", valid_file_name)

    torch.save(train_dataset, train_path)
    logger.info(f"saved train dataset at {train_path}")

    torch.save(valid_dataset, valid_path)
    logger.info(f"saved validation dataset at {valid_path}")

    tokenizer_path = f"{data_args.model_type}_qg_tokenizer"
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(tokenizer_path, exist_ok=True)
    tokenizer.save_pretrained(tokenizer_path)
    logger.info(f"saved tokenizer at {tokenizer_path}")


if __name__ == "__main__":
    main()
question_generation/question_generation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
question_generation/run_qg.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import json
3
+ import logging
4
+ import os
5
+ import sys
6
+ from dataclasses import dataclass, field
7
+ from typing import Dict, List, Optional
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from transformers import (
13
+ AutoModelForSeq2SeqLM,
14
+ AutoTokenizer,
15
+ T5Tokenizer,
16
+ BartTokenizer,
17
+ HfArgumentParser,
18
+ DataCollator,
19
+ TrainingArguments,
20
+ set_seed,
21
+ )
22
+
23
+ from trainer import Trainer
24
+ from data_collator import T2TDataCollator
25
+ from utils import freeze_embeds, assert_not_all_frozen
26
+
27
+ MODEL_TYPE_TO_TOKENIZER = {
28
+ "t5": T5Tokenizer,
29
+ "bart": BartTokenizer,
30
+ }
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
@dataclass
class ModelArguments:
    """
    Settings that select the pretrained model and tokenizer to fine-tune, plus
    two model-side training switches (label smoothing and embedding freezing).

    ``model_name_or_path`` and ``model_type`` have no default and must be supplied.
    """

    # Required: where the pretrained weights come from.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # Required: selects the tokenizer class to use.
    model_type: str = field(
        metadata={"help": "One of 't5', 'bart'"}
    )
    # Optional overrides; None means "not set".
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from s3"},
    )
    # 0 disables label smoothing.
    label_smoothing: Optional[float] = field(
        default=0,
        metadata={"help": "label smoothing rate, set to > 0 if you want to enable lable smoothing"},
    )
    freeze_embeds: bool = field(
        default=False,
        metadata={"help": "Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."},
    )
60
+
61
@dataclass
class DataTrainingArguments:
    """
    Settings that describe the data fed to the model for training and evaluation.

    The two cached-dataset paths have no default and must be supplied.
    """

    # Required: locations of the cached datasets.
    train_file_path: str = field(
        metadata={"help": "Path for cached train dataset"}
    )
    valid_file_path: str = field(
        metadata={"help": "Path for cached valid dataset"}
    )
    # Optional settings; None means "not set".
    data_dir: Optional[str] = field(
        default=None, metadata={"help": "Path for data files"}
    )
    task: Optional[str] = field(
        default=None,
        metadata={"help": "Which task 'qa', 'qg', 'e2e_qg', 'ans_ext', 'multi'. 'multi' means 'qa', 'qg', 'ans_ext' tasks"},
    )
    qg_format: Optional[str] = field(
        default='prepend_qg_format',
        metadata={"help": "How to format inputs for que generation, 'highlight_qg_format' or 'prepend_qg_format'"},
    )
    # Length limits for source and target text.
    max_source_length: Optional[int] = field(
        default=512, metadata={"help": "Max input length for the source text"}
    )
    max_target_length: Optional[int] = field(
        default=32, metadata={"help": "Max input length for the target text"}
    )
92
+
93
+
94
def main(args_file=None):
    """
    Fine-tune and/or evaluate a seq2seq model for question generation.

    Arguments are read from ``args_file`` when it is given; otherwise from a
    single ``.json`` path passed as the only command-line argument; otherwise
    from regular command-line flags (see ``--help`` for the full list).

    Returns a dict of evaluation metrics, which is empty when evaluation was
    not run. Raises ``ValueError`` when training would write into a non-empty
    output directory without ``overwrite_output_dir``.
    """
    arg_parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    # A lone "*.json" argument on the command line is treated as an arguments file.
    json_on_cli = len(sys.argv) == 2 and sys.argv[1].endswith(".json")
    if json_on_cli or args_file is not None:
        if args_file is None:
            args_file_path = os.path.abspath(sys.argv[1])
        else:
            args_file_path = args_file
        model_args, data_args, training_args = arg_parser.parse_json_file(json_file=args_file_path)
    else:
        model_args, data_args, training_args = arg_parser.parse_args_into_dataclasses()

    assert model_args.model_type in tuple(MODEL_TYPE_TO_TOKENIZER), "model type should be 't5' or 'bart'"

    # Refuse to train into a populated output directory unless explicitly allowed.
    output_dir_in_use = os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir)
    if output_dir_in_use and training_args.do_train and not training_args.overwrite_output_dir:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: INFO for local rank -1 or 0, WARN for every other rank.
    if training_args.local_rank in [-1, 0]:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=log_level,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Set project name
    os.environ["WANDB_PROJECT"] = "question-generation"

    # Load pretrained tokenizer and model. The tokenizer falls back to the
    # model path when no dedicated tokenizer path was given.
    tokenizer_source = model_args.tokenizer_name_or_path or model_args.model_name_or_path
    tokenizer = MODEL_TYPE_TO_TOKENIZER[model_args.model_type].from_pretrained(
        tokenizer_source,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    # Make the embedding matrix match the tokenizer's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    if model_args.freeze_embeds:
        logger.info("freezing embeddings of the model")
        freeze_embeds(model)
        assert_not_all_frozen(model)

    # Get datasets: each one is loaded only when the matching phase is enabled.
    logger.info('loading dataset')

    train_dataset = None
    if training_args.do_train:
        train_dataset = torch.load(data_args.train_file_path)
    valid_dataset = None
    if training_args.do_eval:
        valid_dataset = torch.load(data_args.valid_file_path)

    logger.info('finished loading dataset')

    # Initialize data_collator
    data_collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=model_args.model_type,
        mode="training",
        using_tpu=training_args.tpu_num_cores is not None,
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        prediction_loss_only=True,
        label_smoothing=model_args.label_smoothing,
    )

    # disable wandb console logs
    logging.getLogger('wandb.run_manager').setLevel(logging.WARNING)

    # Training
    if training_args.do_train:
        # Pass the model path to train() only when it is a local directory.
        if os.path.isdir(model_args.model_name_or_path):
            resume_path = model_args.model_name_or_path
        else:
            resume_path = None
        trainer.train(model_path=resume_path)
        trainer.save_model()
        # For convenience, the tokenizer is re-saved next to the model so the
        # output directory can be shared as a whole.
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for metric_name in sorted(eval_output):
                metric_text = str(eval_output[metric_name])
                logger.info(" %s = %s", metric_name, metric_text)
                writer.write("%s = %s\n" % (metric_name, metric_text))

        results.update(eval_output)

    return results
223
+
224
+
225
def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); ``index`` is accepted but not used."""
    main()
228
+
229
def run_qg(args_dict):
    """
    Run ``main`` from Python code instead of the command line.

    ``args_dict`` is serialised to ``args.json`` in the current working
    directory (overwriting any existing file of that name) and that file is
    handed to ``main``.

    Returns whatever ``main`` returns (its evaluation results dict), so callers
    no longer have to re-read the metrics from disk. Earlier versions
    discarded that value and always returned ``None``.
    """
    args_path = "args.json"
    with open(args_path, 'w') as f:
        json.dump(args_dict, f)

    return main(args_file=args_path)
234
+
235
+ if __name__ == "__main__":
236
+ main()
question_generation/trainer.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from transformers import Trainer as HFTrainer
7
+ from transformers.file_utils import is_apex_available
8
+
9
+ if is_apex_available():
10
+ from apex import amp
11
+
12
+ from utils import label_smoothed_nll_loss
13
+
14
class Trainer(HFTrainer):
    """
    Hugging Face ``Trainer`` extended with optional label smoothing.

    ``label_smoothing`` is the smoothing rate; 0 (the default) keeps the loss
    returned by the model itself. All other keyword arguments are forwarded
    unchanged to the base class.
    """

    def __init__(self, label_smoothing: float = 0, **kwargs):
        super().__init__(**kwargs)
        self.label_smoothing = label_smoothing

    # override to support label smoothing
    def _training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
    ) -> float:
        """
        Run one forward/backward pass on ``inputs`` and return the loss as a float.

        ``inputs`` is modified in place: tensors are moved to the training
        device and, when label smoothing is active, ``"labels"`` is removed
        from the dict so the loss is computed here instead of by the model.
        """
        model.train()
        for name, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[name] = value.to(self.args.device)

        # Our model outputs do not work with DataParallel, so forcing return tuple.
        if isinstance(model, nn.DataParallel):
            inputs["return_tuple"] = True

        if self.label_smoothing == 0:
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
        else:
            # A DataParallel-style wrapper keeps the real model under
            # ``.module`` and has no ``config`` of its own, so resolve the
            # config through the wrapped model when there is one.
            pad_token_id = getattr(model, "module", model).config.pad_token_id
            labels = inputs.pop("labels")
            # -100 marks positions to ignore; map them to the pad id so they
            # are valid gather indices and are masked out via ignore_index.
            labels[labels == -100] = pad_token_id
            outputs = model(**inputs)
            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, labels, self.label_smoothing, ignore_index=pad_token_id
            )

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            # ``amp`` is only imported at module level when apex is available.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()
question_generation/utils.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Dict, Iterable, List
2
+ from torch import nn
3
+
4
+ # these functions are taken from transformers repo
5
def grad_status(model: nn.Module) -> Iterable:
    """Return a lazy iterable with the ``requires_grad`` flag of each parameter of ``model``."""
    parameters = model.parameters()
    return (weight.requires_grad for weight in parameters)
7
+
8
def freeze_params(model: nn.Module):
    """Set ``requires_grad`` to False on every parameter of ``model`` (in place)."""
    for weight in model.parameters():
        weight.requires_grad = False
11
+
12
def freeze_embeds(model: nn.Module):
    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
    try:
        # Layout with the network nested under ``model.model``.
        inner = model.model
        freeze_params(inner.shared)
        for stack in (inner.encoder, inner.decoder):
            freeze_params(stack.embed_positions)
            freeze_params(stack.embed_tokens)
    except AttributeError:
        # Flat layout: shared embeddings and encoder/decoder live directly on
        # ``model``, and only the token embeddings are frozen.
        freeze_params(model.shared)
        for stack in (model.encoder, model.decoder):
            freeze_params(stack.embed_tokens)
23
+
24
def assert_not_all_frozen(model):
    """Assert that at least one parameter of ``model`` still requires gradients."""
    trainable_flags: List[bool] = list(grad_status(model))
    npars = len(trainable_flags)
    assert any(trainable_flags), f"none of {npars} weights require grad"
28
+
29
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """
    Label-smoothed negative log-likelihood loss (from fairseq).

    Args:
        lprobs: log-probabilities with the classes on the last dimension.
        target: class indices; a missing trailing dimension is added. Every
            value, including ``ignore_index``, is used as a gather index and
            must therefore be a valid class index.
        epsilon: smoothing rate; 0 reduces the loss to the plain NLL.
        ignore_index: target value whose positions contribute nothing to the
            loss, or ``None`` to count every position.

    Returns:
        ``(loss, nll_loss)``: the smoothed loss and the unsmoothed NLL. With
        ``ignore_index`` set, both are averaged over the positions that are
        NOT ignored; with ``ignore_index=None`` they are divided by
        ``lprobs.shape[0]``.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
        # Normalise by the number of positions that actually contribute.
        # Counting the ignored positions instead (as before) divides by zero
        # whenever a batch has no padding. The clamp keeps an all-ignored
        # batch at a finite loss of 0.
        bs = target.ne(ignore_index).long().sum().clamp(min=1)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
        bs = lprobs.shape[0]

    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    # Spread the smoothing mass uniformly over all classes.
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss / bs, nll_loss / bs