|
import os |
|
from collections import deque |
|
|
|
import torch |
|
from torch.utils.data import Dataset |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CNNDMDataset(Dataset):
    """Abstracts the dataset used to train seq2seq models.

    The class will process the documents that are located in the specified
    folder. The preprocessing will work on any document that is reasonably
    formatted. On the CNN/DailyMail dataset it will extract both the story
    and the summary.

    CNN/Daily News:

    The CNN/Daily News raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the `path` argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, path="", prefix="train"):
        """List all the documents to summarize, without reading them.

        Files are not read in memory due to the size of some datasets (like
        CNN/DailyMail); each one is opened lazily in ``__getitem__``.

        Args:
            path (str): directory containing the story files. It is asserted
                to be an existing directory.
            prefix (str): currently unused; kept for backward compatibility.
        """
        assert os.path.isdir(path)

        self.documents = []
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        for story_filename in sorted(os.listdir(path)):
            # Files with "summary" in their name are not stories.
            if "summary" in story_filename:
                continue
            path_to_story = os.path.join(path, story_filename)
            # Skip sub-directories and other non-regular entries.
            if not os.path.isfile(path_to_story):
                continue
            self.documents.append(path_to_story)

    def __len__(self):
        """Returns the number of documents."""
        return len(self.documents)

    def __getitem__(self, idx):
        """Read the ``idx``-th document and split it into story and summary.

        Returns:
            tuple: ``(document_name, story_lines, summary_lines)`` where
            ``document_name`` is the base name of the file and the two lists
            come from ``process_story``.
        """
        document_path = self.documents[idx]
        # os.path.basename is portable; splitting on "/" fails on Windows paths.
        document_name = os.path.basename(document_path)
        with open(document_path, encoding="utf-8") as source:
            raw_story = source.read()
        story_lines, summary_lines = process_story(raw_story)
        return document_name, story_lines, summary_lines
|
|
|
|
|
def process_story(raw_story):
    """Extract the story and summary from a story file.

    Each line is stripped, empty lines are dropped, and ``_add_missing_period``
    is applied to the remaining ones. Every line before the first line that
    starts with ``@highlight`` belongs to the story. The lines after it, minus
    the ``@highlight`` marker lines themselves, form the summary.

    Arguments:
        raw_story (str): content of the story file as an utf-8 encoded string.

    Returns:
        tuple(list[str], list[str]): ``(story_lines, summary_lines)``. If the
        text contains no ``@highlight`` line, every line is returned as story
        and ``summary_lines`` is empty. Empty input yields ``([], [])``; no
        exception is raised in either case.
    """
    stripped_lines = (line.strip() for line in raw_story.split("\n"))
    nonempty_lines = [_add_missing_period(line) for line in stripped_lines if line]

    # Index of the first highlight marker; defaults to "past the end" so a
    # story without highlights keeps every line and gets an empty summary.
    first_highlight = next(
        (i for i, line in enumerate(nonempty_lines) if line.startswith("@highlight")),
        len(nonempty_lines),
    )

    story_lines = nonempty_lines[:first_highlight]
    # Highlights are separated by further "@highlight" marker lines; drop them.
    summary_lines = [
        line
        for line in nonempty_lines[first_highlight + 1 :]
        if not line.startswith("@highlight")
    ]
    return story_lines, summary_lines
|
|
|
|
|
def _add_missing_period(line): |
|
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u2019", ")"] |
|
if line.startswith("@highlight"): |
|
return line |
|
if line[-1] in END_TOKENS: |
|
return line |
|
return line + "." |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def truncate_or_pad(sequence, block_size, pad_token_id):
    """Adapt the source and target sequences' lengths to the block size.

    Sequences longer than ``block_size`` are cut to their first
    ``block_size`` tokens. Shorter ones get padding tokens appended to the
    right. The input sequence is never modified in place.

    Args:
        sequence (list[int]): token ids.
        block_size (int): target length.
        pad_token_id (int): token id used for padding.

    Returns:
        A sequence of exactly ``block_size`` tokens.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    # Build a new list rather than ``extend``-ing the argument, so callers'
    # lists are not silently mutated (the truncation branch never mutated).
    return list(sequence) + [pad_token_id] * (block_size - len(sequence))
|
|
|
|
|
def build_mask(sequence, pad_token_id):
    """Return the attention mask for ``sequence``.

    The result has the same shape and dtype as ``sequence``. It holds 0 at
    every position equal to ``pad_token_id`` and 1 elsewhere, so the
    attention mechanism only attends to non-padding positions.
    """
    is_padding = sequence == pad_token_id
    return torch.where(is_padding, torch.zeros_like(sequence), torch.ones_like(sequence))
|
|
|
|
|
def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """Encode the story and summary lines and flatten each into one sequence.

    Every line is encoded separately with ``tokenizer.encode``, and the
    resulting token ids are concatenated in order. This function inserts no
    separator tokens itself. Any sentence-boundary tokens (such as the
    ``[CLS]``/``[SEP]`` pairs used by Liu & Lapata, 2019) come only from what
    ``tokenizer.encode`` adds around each line, so they depend on the
    tokenizer's configuration.

    Args:
        story_lines (list[str]): sentences of the story.
        summary_lines (list[str]): sentences of the summary.
        tokenizer: object exposing ``encode(str)`` that returns an iterable of
            token ids.

    Returns:
        tuple: ``(story_token_ids, summary_token_ids)`` as flat lists.
    """

    def _encode_and_flatten(lines):
        # Encode sentence by sentence so that any per-sentence special tokens
        # added by the tokenizer are kept between sentences.
        return [token for line in lines for token in tokenizer.encode(line)]

    story_token_ids = _encode_and_flatten(story_lines)
    summary_token_ids = _encode_and_flatten(summary_lines)
    return story_token_ids, summary_token_ids
|
|
|
|
|
def compute_token_type_ids(batch, separator_token_id):
    """Segment embeddings as described in [1].

    Each separator token starts a new segment, and segment ids alternate
    between 0 and 1. The values {0,1} were found in the repository [2].
    Tokens that appear before the first separator of a row get id 1.

    Args:
        batch: torch.Tensor, size [batch_size, block_size]
            Batch of input.
        separator_token_id: int
            The value of the token that separates the segments.

    Returns:
        torch.Tensor of the same [batch_size, block_size] layout with the
        alternating segment ids.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
    """

    def _row_segment_ids(row):
        segment_ids = []
        # Start at 1, which equals (-1) % 2; every separator flips the parity,
        # so the first segment (opened by the first separator) gets 0.
        parity = 1
        for token in row:
            if token == separator_token_id:
                parity ^= 1
            segment_ids.append(parity)
        return segment_ids

    return torch.tensor([_row_segment_ids(row) for row in batch])
|
|