# coding: UTF-8

import time
import torch
import random
from tqdm import tqdm
from datetime import timedelta

def get_time_dif(start_time):
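    """Return the wall-clock time elapsed since start_time as a timedelta."""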
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))

class DataProcessor(object):
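    """Shuffled batch iterator over a tab-separated text/label file,
    yielding tokenized (inputs, labels) pairs on the configured device."""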
    def __init__(self, path, device, tokenizer, batch_size, max_seq_len, seed):
        self.seed = seed
        self.device = device
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len

        self.data = self.load(path)

        self.index = 0
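        # residue marks a final batch that is smaller than batch_size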
        self.residue = False
        self.num_samples = len(self.data[0])
        self.num_batches = self.num_samples // self.batch_size
        if self.num_samples % self.batch_size != 0:
            self.residue = True

    def load(self, path):
        """Read a tab-separated text/label file and return seeded-shuffled
        (contents, labels) lists."""
        contents = []
        labels = []
        with open(path, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    continue
                if "\t" not in line:
                    continue
                # split on the last tab so tabs inside the text cannot break unpacking
                content, label = line.rsplit("\t", 1)
                contents.append(content)
                labels.append(int(label))
        # shuffle contents and labels together with one seeded permutation
        index = list(range(len(labels)))
        random.seed(self.seed)
        random.shuffle(index)
        contents = [contents[i] for i in index]
        labels = [labels[i] for i in index]
        return (contents, labels)

    def __next__(self):
        if self.residue and self.index == self.num_batches:
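            # last, partial batch of the epoch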
            batch_x = self.data[0][self.index * self.batch_size: self.num_samples]
            batch_y = self.data[1][self.index * self.batch_size: self.num_samples]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch
        elif self.index >= self.num_batches:
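            # epoch exhausted: reset the cursor so iteration can restart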
            self.index = 0
            raise StopIteration
        else:
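            # a full batch at the current cursor position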
            batch_x = self.data[0][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch_y = self.data[1][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch

    def _to_tensor(self, batch_x, batch_y):
        # tokenize, pad/truncate to max_seq_len, and move the batch to the target device
        inputs = self.tokenizer(
            batch_x,
            padding="max_length",
            max_length=self.max_seq_len,
            truncation="longest_first",
            return_tensors="pt")
        inputs = inputs.to(self.device)
        labels = torch.tensor(batch_y, dtype=torch.long).to(self.device)
        return (inputs, labels)

    def __iter__(self):
        return self
    
    def __len__(self):
        if self.residue:
            return self.num_batches + 1
        else:
            return self.num_batches
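

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving DataProcessor, assuming a HuggingFace
# tokenizer and a tab-separated text/label file; the file path
# "train.txt" and the checkpoint name are hypothetical placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # assumed checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_iter = DataProcessor(
        path="train.txt",  # hypothetical data file
        device=device,
        tokenizer=tokenizer,
        batch_size=32,
        max_seq_len=128,
        seed=42)
    start_time = time.time()
    for inputs, labels in train_iter:
        # inputs: BatchEncoding with input_ids / attention_mask tensors
        # labels: LongTensor of shape (batch_size,)
        pass
    print("Time usage:", get_time_dif(start_time))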