dusense / preprocess.py
yangfan
feat(*): add all for like bert
6bce1f7
raw
history blame
2.94 kB
# coding: UTF-8
import time
import torch
import random
from tqdm import tqdm
from datetime import timedelta
def get_time_dif(start_time):
    """Return the wall-clock time elapsed since ``start_time``.

    Args:
        start_time: A timestamp previously obtained from ``time.time()``.

    Returns:
        A ``datetime.timedelta`` rounded to whole seconds.
    """
    elapsed_seconds = time.time() - start_time
    whole_seconds = int(round(elapsed_seconds))
    return timedelta(seconds=whole_seconds)
class DataProcessor(object):
    """Load a tab-separated ``text<TAB>label`` file and iterate over it in batches.

    Samples are shuffled once at load time using ``seed``. Each batch is
    tokenized with ``tokenizer.batch_encode_plus`` (padded and truncated to
    ``max_seq_len``) and moved to ``device`` together with a ``torch.long``
    label tensor. The object is its own iterator. After yielding the last
    batch it rewinds ``index`` to 0, so it can be looped over once per epoch.
    NOTE: breaking out of a loop early does not rewind it.

    Args:
        path: Path of the UTF-8 data file, one ``text<TAB>label`` pair per line.
        device: Target device for the produced tensors (passed to ``.to``).
        tokenizer: Object providing ``batch_encode_plus`` whose result has a
            ``.to(device)`` method -- presumably a Hugging Face tokenizer
            (TODO confirm against callers).
        batch_size: Number of samples per batch; the last batch may be smaller.
        max_seq_len: Length every encoded sequence is padded/truncated to.
        seed: Seed for the one-time shuffle of the loaded samples.
    """

    def __init__(self, path, device, tokenizer, batch_size, max_seq_len, seed):
        self.seed = seed  # must be set before load(), which shuffles with it
        self.device = device
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.data = self.load(path)  # (contents, labels), already shuffled
        self.index = 0  # index of the next batch __next__ will produce
        self.num_samples = len(self.data[0])
        # Count of *full* batches; a trailing partial batch is flagged by residue.
        self.num_batches = self.num_samples // self.batch_size
        self.residue = self.num_samples % self.batch_size != 0

    def load(self, path):
        """Read ``path`` and return the shuffled ``(contents, labels)`` lists.

        Blank lines and lines without a tab are skipped. The label is the
        text after the *last* tab and must parse as ``int``. Everything before
        it is the content, so content that itself contains tabs is kept intact.

        Raises:
            ValueError: if a label field is not an integer.
        """
        contents = []
        labels = []
        with open(path, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                parsed = self._parse_line(line)
                if parsed is None:
                    continue
                content, label = parsed
                contents.append(content)
                labels.append(label)
        return self._shuffle(contents, labels)

    @staticmethod
    def _parse_line(line):
        """Split one raw line into ``(content, int_label)``, or return None to skip it."""
        line = line.strip()
        if not line or "\t" not in line:
            return None
        # rsplit: only the last field is the label, so a tab inside the text
        # no longer makes the two-name unpacking raise ValueError.
        content, label = line.rsplit("\t", 1)
        return content, int(label)

    def _shuffle(self, contents, labels):
        """Return ``(contents, labels)`` reordered by the same seeded permutation."""
        # A private Random instance yields exactly the permutation that
        # random.seed(seed) + random.shuffle produced, without reseeding the
        # process-wide `random` generator as a side effect.
        order = list(range(len(labels)))
        random.Random(self.seed).shuffle(order)
        return [contents[i] for i in order], [labels[i] for i in order]

    def __next__(self):
        """Return the next ``(inputs, labels)`` batch; rewind and stop after the last."""
        start = self.index * self.batch_size
        if self.index < self.num_batches:
            end = start + self.batch_size
        elif self.residue and self.index == self.num_batches:
            end = self.num_samples  # trailing partial batch
        else:
            self.index = 0  # rewind so the next loop starts a fresh epoch
            raise StopIteration
        batch = self._to_tensor(self.data[0][start:end], self.data[1][start:end])
        self.index += 1
        return batch

    def _to_tensor(self, batch_x, batch_y):
        """Tokenize ``batch_x`` and build a long label tensor, both on ``self.device``."""
        inputs = self.tokenizer.batch_encode_plus(
            batch_x,
            padding="max_length",
            max_length=self.max_seq_len,
            truncation="longest_first",
            return_tensors="pt")
        inputs = inputs.to(self.device)
        # Build directly on the target device instead of CPU-then-copy.
        labels = torch.tensor(batch_y, dtype=torch.long, device=self.device)
        return (inputs, labels)

    def __iter__(self):
        return self

    def __len__(self):
        """Number of batches per pass, counting a trailing partial batch."""
        return self.num_batches + 1 if self.residue else self.num_batches