File size: 1,173 Bytes
7962ed0
 
 
 
9051af7
7962ed0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e6e04
7962ed0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import h5py
from tqdm import tqdm
import numpy as np
import codecs
from knowledge.utils import file_hash


class TextDB:
    """In-memory text database loaded from an HDF5 file.

    The file is expected to contain top-level groups named ``"0"``, ``"1"``,
    ... (one per chunk), each holding a ``feature`` dataset and a ``text``
    dataset.

    Attributes:
        feature: ``float16`` array of shape ``(db_size, d)`` with the feature
            rows of all chunks concatenated in chunk order.
        text: list of strings, in the same chunk order as ``feature``.
        file_hash: value returned by ``file_hash(text_db)`` for the source
            file.
    """

    def __init__(self, text_db):
        """Load the database stored at path ``text_db``."""
        self.feature, self.text = self.load(text_db)
        self.file_hash = file_hash(text_db)

    def load(self, text_db):
        """Read every chunk of ``text_db`` into memory.

        Args:
            text_db: path of the HDF5 file to read.

        Returns:
            A ``(feature, text)`` tuple: the concatenated ``float16`` feature
            matrix and the list of decoded text entries.
        """
        # Single open: the sizing pass and the copy pass share one handle.
        with h5py.File(text_db, 'r') as f:
            n_chunks = len(f)
            db_size = sum(len(f[f"{i}/feature"]) for i in range(n_chunks))
            _, d = f["0/feature"].shape

            # Pre-allocate once so chunks are copied straight into place.
            feature = np.zeros((db_size, d), dtype=np.float16)
            text = []
            offset = 0
            for i in tqdm(range(n_chunks), desc="Load text DB", dynamic_ncols=True, mininterval=1.0):
                chunk = f[f"{i}/feature"][:]
                feature[offset:offset + len(chunk)] = chunk
                offset += len(chunk)

                text.extend(f[f"{i}/text"][:])

        # Entries normally arrive as bytes; pass through anything that is
        # already a str (codecs.decode would raise TypeError on it).
        text = [t if isinstance(t, str) else codecs.decode(t) for t in text]

        return feature, text

    def __getitem__(self, idx):
        """Return the ``(feature, text)`` pair(s) selected by ``idx``.

        Args:
            idx: an integer, a slice, a sequence/array of integer positions,
                or a 1-D boolean mask (NumPy array or non-empty list of
                bools).

        Returns:
            ``(feature, text)``. For an integer ``idx`` the text is a single
            string; for every other supported index it is a list of strings
            matching the selected feature rows.
        """
        features = self.feature[idx]

        # NumPy applies a boolean mask to ``feature``; iterating that mask
        # would index the text list with True/False (i.e. 1/0), so convert
        # the mask to integer positions first.
        mask = None
        if isinstance(idx, np.ndarray):
            if idx.dtype == np.bool_ and idx.ndim == 1:
                mask = idx
        elif isinstance(idx, list) and idx and all(
            isinstance(i, (bool, np.bool_)) for i in idx
        ):
            mask = np.asarray(idx)
        if mask is not None:
            return features, [self.text[i] for i in np.flatnonzero(mask)]

        # Only the iterability probe is guarded, so a TypeError caused by a
        # bad element inside an iterable index is reported as-is.
        try:
            positions = iter(idx)
        except TypeError:
            # Scalar integer or slice: index the list directly.
            return features, self.text[idx]

        return features, [self.text[i] for i in positions]