import json import os import random import numpy as np from paddle.io import Dataset from .imaug import create_operators, transform class PubTabDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): super(PubTabDataSet, self).__init__() self.logger = logger global_config = config["Global"] dataset_config = config[mode]["dataset"] loader_config = config[mode]["loader"] label_file_path = dataset_config.pop("label_file_path") self.data_dir = dataset_config["data_dir"] self.do_shuffle = loader_config["shuffle"] self.do_hard_select = False if "hard_select" in loader_config: self.do_hard_select = loader_config["hard_select"] self.hard_prob = loader_config["hard_prob"] if self.do_hard_select: self.img_select_prob = self.load_hard_select_prob() self.table_select_type = None if "table_select_type" in loader_config: self.table_select_type = loader_config["table_select_type"] self.table_select_prob = loader_config["table_select_prob"] self.seed = seed logger.info("Initialize indexs of datasets:%s" % label_file_path) with open(label_file_path, "rb") as f: self.data_lines = f.readlines() self.data_idx_order_list = list(range(len(self.data_lines))) if mode.lower() == "train": self.shuffle_data_random() self.ops = create_operators(dataset_config["transforms"], global_config) ratio_list = dataset_config.get("ratio_list", [1.0]) self.need_reset = True in [x < 1 for x in ratio_list] def shuffle_data_random(self): if self.do_shuffle: random.seed(self.seed) random.shuffle(self.data_lines) return def __getitem__(self, idx): try: data_line = self.data_lines[idx] data_line = data_line.decode("utf-8").strip("\n") info = json.loads(data_line) file_name = info["filename"] select_flag = True if self.do_hard_select: prob = self.img_select_prob[file_name] if prob < random.uniform(0, 1): select_flag = False if self.table_select_type: structure = info["html"]["structure"]["tokens"].copy() structure_str = "".join(structure) table_type = "simple" if "colspan" in structure_str or "rowspan" in structure_str: table_type = "complex" if table_type == "complex": if self.table_select_prob < random.uniform(0, 1): select_flag = False if select_flag: cells = info["html"]["cells"].copy() structure = info["html"]["structure"].copy() img_path = os.path.join(self.data_dir, file_name) data = {"img_path": img_path, "cells": cells, "structure": structure} if not os.path.exists(img_path): raise Exception("{} does not exist!".format(img_path)) with open(data["img_path"], "rb") as f: img = f.read() data["image"] = img outs = transform(data, self.ops) else: outs = None except Exception as e: self.logger.error( "When parsing line {}, error happened with msg: {}".format(data_line, e) ) outs = None if outs is None: return self.__getitem__(np.random.randint(self.__len__())) return outs def __len__(self): return len(self.data_idx_order_list)