sooks committed
Commit 2a093da · 1 Parent(s): 3f7db8d

Create dataset.py

Files changed (1)
  1. detector/dataset.py +86 -0
detector/dataset.py ADDED
@@ -0,0 +1,86 @@
import json
import numpy as np
from typing import List

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from .download import download


def load_texts(data_file, expected_size=None):
    """Read one JSON object per line and collect its 'text' field."""
    texts = []

    with open(data_file) as f:
        for line in tqdm(f, total=expected_size, desc=f'Loading {data_file}'):
            texts.append(json.loads(line)['text'])

    return texts


class Corpus:
    """Downloads a named dataset and loads its train/test/valid JSONL splits."""

    def __init__(self, name, data_dir='data', skip_train=False):
        download(name, data_dir=data_dir)
        self.name = name
        self.train = load_texts(f'{data_dir}/{name}.train.jsonl', expected_size=250000) if not skip_train else None
        self.test = load_texts(f'{data_dir}/{name}.test.jsonl', expected_size=5000)
        self.valid = load_texts(f'{data_dir}/{name}.valid.jsonl', expected_size=5000)


class EncodedDataset(Dataset):
    """Tokenizes real/fake texts into (tokens, mask, label) examples.

    Labels are 1 for real (human-written) text and 0 for fake (generated) text.
    """

    def __init__(self, real_texts: List[str], fake_texts: List[str], tokenizer: PreTrainedTokenizer,
                 max_sequence_length: int = None, min_sequence_length: int = None, epoch_size: int = None,
                 token_dropout: float = None, seed: int = None):
        self.real_texts = real_texts
        self.fake_texts = fake_texts
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.min_sequence_length = min_sequence_length
        self.epoch_size = epoch_size
        self.token_dropout = token_dropout
        self.random = np.random.RandomState(seed)

    def __len__(self):
        return self.epoch_size or len(self.real_texts) + len(self.fake_texts)

    def __getitem__(self, index):
        if self.epoch_size is not None:
            # Sampling mode: ignore the index and draw a random (label, text) pair.
            label = self.random.randint(2)
            texts = [self.fake_texts, self.real_texts][label]
            text = texts[self.random.randint(len(texts))]
        else:
            # Deterministic mode: real texts come first, then fake texts.
            if index < len(self.real_texts):
                text = self.real_texts[index]
                label = 1
            else:
                text = self.fake_texts[index - len(self.real_texts)]
                label = 0

        tokens = self.tokenizer.encode(text)

        if self.max_sequence_length is None:
            # Leave room for the BOS and EOS tokens added below.
            # Note: `tokenizer.max_len` was renamed to `model_max_length` in transformers v3+.
            tokens = tokens[:self.tokenizer.model_max_length - 2]
        else:
            # Optionally sample a random output length, then a random window of that length.
            output_length = min(len(tokens), self.max_sequence_length)
            if self.min_sequence_length:
                output_length = self.random.randint(min(self.min_sequence_length, len(tokens)), output_length + 1)
            start_index = 0 if len(tokens) <= output_length else self.random.randint(0, len(tokens) - output_length + 1)
            end_index = start_index + output_length
            tokens = tokens[start_index:end_index]

        if self.token_dropout:
            # Randomly replace tokens with <unk> as a regularizer.
            # Note: `np.bool` was removed in NumPy 1.24; use the builtin `bool` instead.
            dropout_mask = self.random.binomial(1, self.token_dropout, len(tokens)).astype(bool)
            tokens = np.array(tokens)
            tokens[dropout_mask] = self.tokenizer.unk_token_id
            tokens = tokens.tolist()

        if self.max_sequence_length is None or len(tokens) == self.max_sequence_length:
            # No padding needed: add BOS/EOS and an all-ones attention mask.
            mask = torch.ones(len(tokens) + 2)
            return torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]), mask, label

        # Pad to max_sequence_length and zero out the mask over the padding.
        padding = [self.tokenizer.pad_token_id] * (self.max_sequence_length - len(tokens))
        tokens = torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id] + padding)
        mask = torch.ones(tokens.shape[0])
        mask[-len(padding):] = 0
        return tokens, mask, label
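
A minimal sketch of how these pieces might be wired together for training. The corpus names ('webtext', 'xl-1542M') and the RoBERTa tokenizer follow the usual GPT-2 output detector setup, but they are assumptions here, not part of this commit:

# Hypothetical usage (not part of this file): building a DataLoader from
# Corpus + EncodedDataset. Corpus names and the RoBERTa tokenizer are
# assumptions borrowed from the GPT-2 output detector training setup.
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer

from detector.dataset import Corpus, EncodedDataset

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

real_corpus = Corpus('webtext', data_dir='data')    # human-written text
fake_corpus = Corpus('xl-1542M', data_dir='data')   # GPT-2 generations

# epoch_size switches __getitem__ into random-sampling mode, so the loader
# itself does not need shuffle=True.
train_dataset = EncodedDataset(real_corpus.train, fake_corpus.train, tokenizer,
                               max_sequence_length=128, epoch_size=250000,
                               token_dropout=0.1, seed=0)
train_loader = DataLoader(train_dataset, batch_size=24)

tokens, mask, labels = next(iter(train_loader))
print(tokens.shape, mask.shape, labels.shape)  # e.g. (24, 130), (24, 130), (24,)

With max_sequence_length=128, every example is padded to 130 positions (BOS + 128 tokens + EOS), so default batch collation works without a custom collate_fn.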