# coding: UTF-8
import os
import time
import random
from datetime import timedelta

import torch
from tqdm import tqdm

from ebart import PegasusSummarizer  # used by the (currently disabled) summarization step


def get_time_dif(start_time):
    """Return the elapsed wall-clock time since start_time as a timedelta."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


## Load labelled text data and serve it as tensor batches
class DataProcessor(object):
    """Iterable batch loader for a text-classification dataset.

    The data file is expected to hold one sample per line in the form
    "content<TAB>label"; label names are mapped to integer ids via a
    sibling label.txt that lists one label per line.
    """

    def __init__(self, dataPath, device, summarizerModel, tokenizer, batch_size, max_seq_len, seed):
        self.seed = seed
        self.device = device
        self.tokenizer = tokenizer
        self.summarizerModel = summarizerModel
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.data = self.load(dataPath)
        self.index = 0
        # residue is True when num_samples is not a multiple of batch_size,
        # i.e. the last batch is smaller than batch_size.
        self.residue = False
        self.num_samples = len(self.data[0])
        self.num_batches = self.num_samples // self.batch_size
        if self.num_samples % self.batch_size != 0:
            self.residue = True

    def load(self, path):
        contents = []
        labels = []
        labels_map = {}
        # Build the label-name -> id mapping from label.txt, which lives
        # next to the data file.
        labels_path = os.path.join(os.path.dirname(path), "label.txt")
        with open(labels_path, mode="r", encoding="UTF-8") as f:
            for index, line in enumerate(tqdm(f)):
                line = line.strip()
                if line:
                    labels_map[line] = index
        print("labels_map [{}]".format(labels_map))

        with open(path, mode="r", encoding="UTF-8") as f:
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    continue
                if '\t' not in line:
                    continue
                # Split on the last tab so tabs inside the content survive.
                content, label = line.rsplit('\t', 1)
                mapped_label = labels_map.get(label)
                if mapped_label is not None and isinstance(mapped_label, int) and mapped_label >= 0:
                    # Optional summarization pre-processing, currently disabled:
                    # content = self.summarizerModel.generate_summary(content, 128, 64)
                    contents.append(content)
                    labels.append(mapped_label)
                else:
                    print("unmatched label [{}:{}]".format(content, mapped_label))

        # Shuffle contents and labels with the same permutation so the
        # pairing is preserved; the fixed seed keeps runs reproducible.
        index = list(range(len(labels)))
        random.seed(self.seed)
        random.shuffle(index)
        contents = [contents[i] for i in index]
        labels = [labels[i] for i in index]
        print("loaded data: {} contents, {} labels".format(len(contents), len(labels)))
        return (contents, labels)

    def __next__(self):
        if self.residue and self.index == self.num_batches:
            # Final, smaller-than-batch_size batch.
            batch_x = self.data[0][self.index * self.batch_size: self.num_samples]
            batch_y = self.data[1][self.index * self.batch_size: self.num_samples]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch
        elif self.index >= self.num_batches:
            # Exhausted: reset so the processor can be iterated again.
            self.index = 0
            raise StopIteration
        else:
            batch_x = self.data[0][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch_y = self.data[1][self.index * self.batch_size: (self.index + 1) * self.batch_size]
            batch = self._to_tensor(batch_x, batch_y)
            self.index += 1
            return batch

    def _to_tensor(self, batch_x, batch_y):
        # Tokenize the raw strings into padded/truncated tensors on the
        # target device.
        inputs = self.tokenizer.batch_encode_plus(
            batch_x,
            padding="max_length",
            max_length=self.max_seq_len,
            truncation="longest_first",
            return_tensors="pt")
        inputs = inputs.to(self.device)
        labels = torch.LongTensor(batch_y).to(self.device)
        return (inputs, labels)

    def __iter__(self):
        return self

    def __len__(self):
        # Count the partial batch at the end, if any.
        if self.residue:
            return self.num_batches + 1
        else:
            return self.num_batches
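

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way the
# DataProcessor might be driven. The checkpoint name, data path, and
# hyper-parameters below are hypothetical placeholders; summarizerModel is
# passed as None because the generate_summary() call in load() is disabled.
# Assumes a HuggingFace `transformers` tokenizer, which provides the
# batch_encode_plus() API that _to_tensor() relies on.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # placeholder checkpoint

    train_iter = DataProcessor(
        dataPath="data/train.txt",  # hypothetical path; label.txt must sit beside it
        device=device,
        summarizerModel=None,       # unused while the summarization step is commented out
        tokenizer=tokenizer,
        batch_size=32,
        max_seq_len=128,
        seed=42,
    )

    start_time = time.time()
    print("{} batches per epoch".format(len(train_iter)))
    for inputs, labels in train_iter:
        # A model forward pass would go here; inputs is a BatchEncoding of
        # input_ids / attention_mask tensors and labels is a LongTensor.
        pass
    print("one pass over the data took", get_time_dif(start_time))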