import random
import time
from datetime import timedelta

import torch
from tqdm import tqdm


def get_time_dif(start_time):
    """Return the elapsed wall-clock time since start_time, rounded to whole seconds."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


class DataProcessor(object):
    """Iterable over a tab-separated classification dataset: loads and shuffles
    the samples once, then yields tokenized batches on the target device."""

    def __init__(self, path, device, tokenizer, batch_size, max_seq_len, seed):
        self.seed = seed
        self.device = device
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len

        self.data = self.load(path)

        # Batch bookkeeping: residue is True when the sample count does not
        # divide evenly, leaving a final batch smaller than batch_size.
        self.index = 0
        self.residue = False
        self.num_samples = len(self.data[0])
        self.num_batches = self.num_samples // self.batch_size
        if self.num_samples % self.batch_size != 0:
            self.residue = True

    def load(self, path):
        """Read "content<TAB>label" lines into parallel lists and shuffle them in unison."""
        contents = []
        labels = []
        with open(path, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    continue
                if "\t" not in line:
                    continue
                # Split on the last tab so stray tabs inside the content
                # don't break the content/label unpacking.
                content, label = line.rsplit("\t", 1)
                contents.append(content)
                labels.append(int(label))

        # Shuffle reproducibly: one seeded permutation applied to both lists
        # keeps each content aligned with its label.
        index = list(range(len(labels)))
        random.seed(self.seed)
        random.shuffle(index)
        contents = [contents[i] for i in index]
        labels = [labels[i] for i in index]
        return (contents, labels)

    def __next__(self):
        if self.residue and self.index == self.num_batches:
            # Serve the final, smaller-than-batch_size batch of leftovers.
            batch_x = self.data[0][self.index * self.batch_size: self.num_samples]
            batch_y = self.data[1][self.index * self.batch_size: self.num_samples]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch
        elif self.index >= self.num_batches:
            # Epoch finished: reset the cursor so the object can be iterated again.
            self.index = 0
            raise StopIteration
        else:
            batch_x = self.data[0][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch_y = self.data[1][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch

    def _to_tensor(self, batch_x, batch_y):
        """Tokenize a batch of raw strings and move inputs and labels to the device."""
        inputs = self.tokenizer.batch_encode_plus(
            batch_x,
            padding="max_length",
            max_length=self.max_seq_len,
            truncation="longest_first",
            return_tensors="pt")
        inputs = inputs.to(self.device)
        labels = torch.LongTensor(batch_y).to(self.device)
        return (inputs, labels)

    def __iter__(self):
        return self

    def __len__(self):
        # The trailing partial batch, if any, counts as one extra batch.
        if self.residue:
            return self.num_batches + 1
        else:
            return self.num_batches
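

# Minimal usage sketch (illustrative, not part of the original module): assumes
# a HuggingFace tokenizer ("bert-base-chinese") and a tab-separated file at
# "data/train.txt" with one "content<TAB>label" pair per line; both names are
# hypothetical placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    start_time = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
    train_iter = DataProcessor(
        path="data/train.txt",
        device=device,
        tokenizer=tokenizer,
        batch_size=32,
        max_seq_len=128,
        seed=42,
    )
    print(f"{len(train_iter)} batches per epoch")
    for inputs, labels in train_iter:
        # inputs is a BatchEncoding holding input_ids / attention_mask tensors.
        print(inputs["input_ids"].shape, labels.shape)
        break  # inspect a single batch; breaking early leaves self.index mid-epoch
    print("Time usage:", get_time_dif(start_time))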