dusense / preprocess.py
# coding: UTF-8
import os
import time
import torch
import random
from tqdm import tqdm
from datetime import timedelta
from ebart import PegasusSummarizer

def get_time_dif(start_time):
    """Return elapsed wall-clock time since start_time, rounded to whole seconds."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))
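
# Minimal usage sketch (illustrative only; the surrounding loop is assumed):
#   start_time = time.time()
#   ...  # run preprocessing / training
#   print("Time usage:", get_time_dif(start_time))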

## Build batches of tensors from the raw text data
class DataProcessor(object):
    def __init__(self, dataPath, device, summarizerModel, tokenizer, batch_size, max_seq_len, seed):
        self.seed = seed
        self.device = device
        self.tokenizer = tokenizer
        self.summarizerModel = summarizerModel
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.data = self.load(dataPath)  # (contents, labels), shuffled
        self.index = 0                   # index of the next batch to emit
        self.residue = False             # True if the last batch is smaller than batch_size
        self.num_samples = len(self.data[0])
        self.num_batches = self.num_samples // self.batch_size
        if self.num_samples % self.batch_size != 0:
            self.residue = True

    def load(self, path):
        contents = []
        labels = []
        labels_map = {}
        # Build the label -> index map from label.txt located next to the data file.
        labels_path = os.path.join(os.path.dirname(path), "label.txt")
        with open(labels_path, mode="r", encoding="UTF-8") as f:
            for index, line in enumerate(tqdm(f)):
                line = line.strip()
                if line:
                    labels_map[line] = index
        print("labels_map [{}]".format(labels_map))
        # Each data line is "<content>\t<label>"; skip blank or malformed lines.
        with open(path, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    continue
                if '\t' not in line:
                    continue
                content, label = line.rsplit('\t', 1)
                mapped_label = labels_map.get(label)
                if mapped_label is not None:
                    # Optionally compress long documents before tokenization:
                    # content = self.summarizerModel.generate_summary(content, 128, 64)
                    contents.append(content)
                    labels.append(mapped_label)
                else:
                    print("unmatched label [{}:{}]".format(content, label))
        # Shuffle contents and labels with the same permutation, seeded for reproducibility.
        index = list(range(len(labels)))
        random.seed(self.seed)
        random.shuffle(index)
        contents = [contents[i] for i in index]
        labels = [labels[i] for i in index]
        print("loaded data: contents/labels [{}:{}]".format(len(contents), len(labels)))
        return (contents, labels)

    def __next__(self):
        if self.residue and self.index == self.num_batches:
            # Emit the final, smaller-than-batch_size batch.
            batch_x = self.data[0][self.index * self.batch_size: self.num_samples]
            batch_y = self.data[1][self.index * self.batch_size: self.num_samples]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch
        elif self.index >= self.num_batches:
            # One full pass is done; reset so the next epoch can iterate again.
            self.index = 0
            raise StopIteration
        else:
            batch_x = self.data[0][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch_y = self.data[1][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch

    def _to_tensor(self, batch_x, batch_y):
        # Tokenize, pad/truncate to max_seq_len, and move all tensors to the target device.
        inputs = self.tokenizer.batch_encode_plus(
            batch_x,
            padding="max_length",
            max_length=self.max_seq_len,
            truncation="longest_first",
            return_tensors="pt")
        inputs = inputs.to(self.device)
        labels = torch.LongTensor(batch_y).to(self.device)
        return (inputs, labels)

    def __iter__(self):
        return self

    def __len__(self):
        # Number of batches per epoch, counting a trailing partial batch if present.
        if self.residue:
            return self.num_batches + 1
        else:
            return self.num_batches
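
# A minimal usage sketch, assuming a BERT-style tokenizer from the `transformers`
# library. The checkpoint name and data path below are hypothetical placeholders,
# not part of this repository's verified configuration; summarizerModel is passed
# as None because the generate_summary call in load() is commented out.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")  # assumed checkpoint
    processor = DataProcessor(
        dataPath="data/train.txt",  # hypothetical path: one "<content>\t<label>" per line
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        summarizerModel=None,
        tokenizer=tokenizer,
        batch_size=32,
        max_seq_len=128,
        seed=42,
    )
    print("batches per epoch:", len(processor))
    for inputs, labels in processor:
        # inputs is a BatchEncoding holding input_ids / attention_mask tensors.
        print(inputs["input_ids"].shape, labels.shape)
        break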