import random
import itertools

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

from data import get_data

# load the tokenizer from the locally trained WordPiece vocabulary
tokenizer = BertTokenizer.from_pretrained("bert-it-1/bert-it-vocab.txt")


class BERTDataset(Dataset):
    '''Build (masked input, MLM labels, segment ids, NSP label) examples for BERT pre-training.'''

    def __init__(self, tokenizer: BertTokenizer = tokenizer, data_pair: list = None, seq_len: int = 128) -> None:
        super().__init__()

        # fall back to the movie-dialogue sentence pairs only when no corpus is supplied,
        # instead of evaluating get_data() as a default argument at import time
        if data_pair is None:
            data_pair = get_data('datasets/movie_conversations.txt', 'datasets/movie_lines.txt')

        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.corpus_lines = len(data_pair)
        self.lines = data_pair

    def __len__(self):
        return self.corpus_lines

    def __getitem__(self, item):
        # Step 1: get a sentence pair, either the true next sentence or a random one (NSP)
        t1, t2, is_next_label = self.get_sent(item)

        # Step 2: randomly mask words in each sentence and keep the original ids as labels (MLM)
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # Step 3: wrap sentence A in [CLS] ... [SEP] and terminate sentence B with [SEP];
        # the label positions of the special tokens are filled with [PAD]
        t1 = [self.tokenizer.vocab['[CLS]']] + t1_random + [self.tokenizer.vocab['[SEP]']]
        t2 = t2_random + [self.tokenizer.vocab['[SEP]']]
        t1_label = [self.tokenizer.vocab['[PAD]']] + t1_label + [self.tokenizer.vocab['[PAD]']]
        t2_label = t2_label + [self.tokenizer.vocab['[PAD]']]

        # Step 4: segment ids (1 for sentence A, 2 for sentence B), truncation to seq_len, then padding
        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        padding = [self.tokenizer.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding)
        bert_label.extend(padding)
        segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()}

    def random_word(self, sentence):
        '''mask 15% of the words for the masked-language-model objective'''
        tokens = sentence.split()
        output_label = []
        output = []

        for token in tokens:
            prob = random.random()

            # tokenize the word and drop the [CLS]/[SEP] ids the tokenizer adds
            token_id = self.tokenizer(token)['input_ids'][1:-1]

            if prob < 0.15:
                prob /= 0.15

                # 80% of the time, replace every sub-token with [MASK]
                if prob < 0.8:
                    for _ in range(len(token_id)):
                        output.append(self.tokenizer.vocab['[MASK]'])

                # 10% of the time, replace every sub-token with a random token
                elif prob < 0.9:
                    for _ in range(len(token_id)):
                        output.append(random.randrange(len(self.tokenizer.vocab)))

                # 10% of the time, keep the original sub-tokens
                else:
                    output.append(token_id)

                # the label at every masked position is the original token id
                output_label.append(token_id)

            else:
                # unmasked tokens keep their ids and get label 0, which the loss ignores
                output.append(token_id)
                for _ in range(len(token_id)):
                    output_label.append(0)

        # flatten the nested sub-token lists into flat id sequences
        output = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output]))
        output_label = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output_label]))
        assert len(output) == len(output_label)
        return output, output_label

    def get_sent(self, index):
        '''return a sentence pair plus its is_next label (1 = true next sentence, 0 = random sentence)'''
        t1, t2 = self.get_corpus_line(index)

        # 50% of the time keep the true pair, otherwise swap in a random second sentence
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        '''return the sentence pair stored at the given index'''
        return self.lines[item][0], self.lines[item][1]

    def get_random_line(self):
        '''return a random second sentence'''
        return self.lines[random.randrange(len(self.lines))][1]
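

# Minimal usage sketch (assumes the vocab file and dataset paths referenced above exist
# on disk); it only builds the dataset and inspects one batch, it is not the training loop.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = BERTDataset(seq_len=128)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    batch = next(iter(loader))
    for key, value in batch.items():
        # every tensor is (batch_size, seq_len) except "is_next", which is (batch_size,)
        print(key, value.shape)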