import codecs

import h5py
import numpy as np
from tqdm import tqdm

from knowledge.utils import file_hash


class TextDB:
    """In-memory copy of a text database stored in an HDF5 file.

    The file is read as groups named ``"0"``, ``"1"``, ..., ``str(len(f) - 1)``.
    Each group holds a 2-D ``feature`` dataset and a ``text`` dataset with one
    entry per feature row. All groups are concatenated in group order into a
    single ``float16`` feature matrix and a single list of strings, so row
    ``k`` of ``feature`` pairs with ``text[k]``.

    Attributes:
        feature: ``np.ndarray`` of shape ``(n_rows, dim)``, dtype ``float16``.
        text: ``list`` of strings aligned with the rows of ``feature``.
        file_hash: value returned by ``knowledge.utils.file_hash`` for the
            DB path.
    """

    def __init__(self, text_db):
        """Load ``text_db`` (anything ``h5py.File`` can open) into memory."""
        self.feature, self.text = self.load(text_db)
        self.file_hash = file_hash(text_db)

    def load(self, text_db):
        """Read every group of ``text_db`` and return ``(feature, text)``.

        Returns:
            ``(feature, text)``, where ``feature`` is a ``float16`` array of
            shape ``(total_rows, dim)`` and ``text`` is a list with one string
            per feature row.

        Raises:
            ValueError: if the file contains no groups, or a group's ``text``
                length differs from its ``feature`` row count.
        """
        # A single open covers both sizing and copying.
        with h5py.File(text_db, "r") as f:
            n_groups = len(f)
            if n_groups == 0:
                raise ValueError(f"text DB {text_db!r} contains no groups")

            # Size the output from dataset lengths first, so all rows are
            # copied into one pre-allocated array instead of concatenating.
            db_size = sum(len(f[f"{i}/feature"]) for i in range(n_groups))
            _, dim = f["0/feature"].shape
            feature = np.zeros((db_size, dim), dtype=np.float16)

            raw_text = []
            offset = 0
            for i in tqdm(
                range(n_groups),
                desc="Load text DB",
                dynamic_ncols=True,
                mininterval=1.0,
            ):
                chunk = f[f"{i}/feature"][:]
                texts = f[f"{i}/text"][:]
                # A per-group mismatch would shift every later feature/text
                # pairing, so reject it instead of loading misaligned data.
                if len(texts) != len(chunk):
                    raise ValueError(
                        f"group {i} of {text_db!r}: {len(chunk)} feature rows "
                        f"but {len(texts)} texts"
                    )
                feature[offset:offset + len(chunk)] = chunk
                offset += len(chunk)
                raw_text.extend(texts)

        # Stored strings may be read back as bytes or already as str,
        # depending on the h5py version and dataset dtype. Decode only bytes
        # (codecs.decode defaults to strict UTF-8).
        text = [codecs.decode(t) if isinstance(t, bytes) else t for t in raw_text]
        return feature, text

    def __getitem__(self, idx):
        """Return ``(feature, text)`` selected by ``idx``.

        ``idx`` may be an int, a slice, an iterable of ints, or a 1-D boolean
        mask. Scalars give one feature row and one string. Slices give a
        sliced array and a list. Iterables and masks give the selected rows
        and a list of strings in the same order.
        """
        feature = self.feature[idx]
        if isinstance(idx, (list, np.ndarray)):
            as_array = np.asarray(idx)
            if as_array.dtype == np.bool_ and as_array.ndim == 1:
                # numpy treats a boolean list/array as a row mask. Select the
                # same rows from ``text``, rather than indexing the list with
                # True/False, which Python would treat as 1/0.
                idx = np.flatnonzero(as_array)
        try:
            indices = iter(idx)
        except TypeError:
            # Not iterable (int-like scalar or slice): plain list indexing
            # handles it and mirrors the numpy selection above.
            return feature, self.text[idx]
        return feature, [self.text[i] for i in indices]