"""Data loading and batching utilities for text classification."""

import os
import random
import time
from datetime import timedelta

import torch
from tqdm import tqdm

from ebart import PegasusSummarizer


def get_time_dif(start_time):
    """Return the time elapsed since start_time, rounded to whole seconds."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))

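
# Expected on-disk layout (inferred from DataProcessor.load below; names illustrative):
#   label.txt   one label name per line; the 0-based line index is the class id
#   <dataPath>  one sample per line, tab-separated: <content><TAB><label>
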
class DataProcessor(object):
    """Batch iterator over a tab-separated text-classification dataset."""

    def __init__(self, dataPath, device, summarizerModel, tokenizer, batch_size, max_seq_len, seed):
        self.seed = seed
        self.device = device
        self.tokenizer = tokenizer
        self.summarizerModel = summarizerModel
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len

        self.data = self.load(dataPath)

        self.index = 0
        self.residue = False  # True when the last batch is smaller than batch_size
        self.num_samples = len(self.data[0])
        self.num_batches = self.num_samples // self.batch_size
        if self.num_samples % self.batch_size != 0:
            self.residue = True

    def load(self, path):
        """Read the dataset at `path` and the label vocabulary from a sibling
        `label.txt`, then return shuffled, parallel (contents, labels) lists
        with labels mapped to integer class ids."""
        contents = []
        labels = []
        labels_map = {}

        # Each non-empty line of label.txt is a label name; its 0-based
        # line index becomes the integer class id.
        labels_path = os.path.join(os.path.dirname(path), "label.txt")
        with open(labels_path, mode="r", encoding="UTF-8") as f:
            for index, line in enumerate(tqdm(f)):
                line = line.strip()
                if line:
                    labels_map[line] = index

        print("labels_map [{}]".format(labels_map))
        with open(path, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    continue
                if '\t' not in line:
                    continue
                content, label = line.rsplit('\t', 1)
                mapped_label = labels_map.get(label)
                if mapped_label is not None:
                    contents.append(content)
                    labels.append(mapped_label)
                else:
                    print("unmatched label [{}:{}]".format(content, label))

        # Shuffle contents and labels in lockstep with a fixed seed so that
        # runs are reproducible.
        index = list(range(len(labels)))
        random.seed(self.seed)
        random.shuffle(index)
        contents = [contents[i] for i in index]
        labels = [labels[i] for i in index]
        print("loaded {} contents and {} labels".format(len(contents), len(labels)))
        return (contents, labels)

    def __next__(self):
        if self.residue and self.index == self.num_batches:
            # Final, smaller batch holding the leftover samples.
            batch_x = self.data[0][self.index * self.batch_size: self.num_samples]
            batch_y = self.data[1][self.index * self.batch_size: self.num_samples]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch
        elif self.index >= self.num_batches:
            # One full pass is done; reset so the iterator can be reused.
            self.index = 0
            raise StopIteration
        else:
            batch_x = self.data[0][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch_y = self.data[1][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch

    def _to_tensor(self, batch_x, batch_y):
        # Tokenize the batch, padding/truncating every sample to max_seq_len,
        # and move both inputs and labels to the target device.
        inputs = self.tokenizer.batch_encode_plus(
            batch_x,
            padding="max_length",
            max_length=self.max_seq_len,
            truncation="longest_first",
            return_tensors="pt")
        inputs = inputs.to(self.device)
        labels = torch.LongTensor(batch_y).to(self.device)
        return (inputs, labels)

    def __iter__(self):
        return self

    def __len__(self):
        # The leftover batch, if any, counts as one extra batch.
        if self.residue:
            return self.num_batches + 1
        else:
            return self.num_batches
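

# A minimal usage sketch (assumed, not part of the original module). It
# presumes a Hugging Face-style tokenizer exposing batch_encode_plus, and
# the data path and checkpoint name are hypothetical placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # assumed checkpoint
    processor = DataProcessor(
        dataPath="data/train.txt",   # hypothetical path; label.txt must sit beside it
        device=device,
        summarizerModel=None,        # a PegasusSummarizer instance in the full pipeline
        tokenizer=tokenizer,
        batch_size=32,
        max_seq_len=128,
        seed=42,
    )
    print("batches per epoch:", len(processor))
    for inputs, labels in processor:
        print(inputs["input_ids"].shape, labels.shape)
        break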