# Corpus HDF5 data-access utilities (page-scrape artifacts removed from header).
import h5py
import numpy as np
from functools import partial
from utils.gen_utils import map_nlist, vround
import regex as re
from spacyface.simple_spacy_token import SimpleSpacyToken
from data_processing.sentence_data_wrapper import SentenceH5Data, TokenH5Data
from utils.f import ifnone
# Every sentence key in the HDF5 file is its index zero-padded to this width.
ZERO_BUFFER = 12
main_key = f"{{:0{ZERO_BUFFER}}}"

def to_idx(idx: int) -> str:
    """Render `idx` as the zero-padded string key used to address the HDF5 file."""
    return main_key.format(idx)
def zip_len_check(*iters):
    """Zip iterables after verifying they all share the same length.

    Raises:
        ValueError: If fewer than two iterables are given, or if any
            iterable's length differs from the first one's.
    """
    if len(iters) < 2:
        raise ValueError(f"Expected at least 2 iterables to combine. Got {len(iters)} iterables")
    lengths = [len(it) for it in iters]
    expected = lengths[0]
    mismatched = [m for m in lengths if m != expected]
    if mismatched:
        # Report the first offending length, matching original behavior.
        raise ValueError(f"Expected all iterations to have len {expected} but found {mismatched[0]}")
    return zip(*iters)
class CorpusDataWrapper:
    """A wrapper for both the token embeddings and the head context.

    This class allows access into an HDF5 file designed according to the
    data/processing module's contents as if it were an in-memory dictionary.
    Sentences live under zero-padded integer keys (see `to_idx`) and are
    exposed as `SentenceH5Data` objects.
    """

    def __init__(self, fname, name=None):
        """Open an hdf5 file of the designed format and provide easy access to its contents.

        Args:
            fname: Path to the HDF5 file produced by the data processing module.
            name: Optional display name. Defaults to "CorpusData".
        """
        # Cursor for the shared iteration protocol (__iter__/__next__ below).
        self.__curr = 0
        self.__name = ifnone(name, "CorpusData")
        self.fname = fname
        self.data = h5py.File(fname, 'r')
        main_keys = self.data.keys()
        self.__len = len(main_keys)
        assert self.__len > 0, "Cannot process an empty file"
        embeds = self[0].embeddings
        self.embedding_dim = embeds.shape[-1]
        self.n_layers = embeds.shape[0] - 1  # 1 was added for the input layer
        self.refmap, self.total_vectors = self._init_vector_map()

    def __del__(self):
        """Close the underlying HDF5 handle when the wrapper is collected."""
        try:
            self.data.close()
        # If run as a script, won't be able to close because of an import error
        except ImportError:
            pass
        except AttributeError:
            # __init__ failed before self.data was ever assigned.
            print(f"Never successfully loaded {self.fname}")

    def __iter__(self):
        # NOTE(review): the object is its own iterator with a single shared
        # cursor, so nested or concurrent iteration is not supported.
        return self

    def __len__(self):
        return self.__len

    def __next__(self):
        """Return the next sentence, resetting the cursor once exhausted."""
        if self.__curr >= self.__len:
            self.__curr = 0
            raise StopIteration
        out = self[self.__curr]
        self.__curr += 1
        return out

    def __getitem__(self, idx):
        """Index into the sentences of the corpus.

        Args:
            idx: An int (negative values count from the end) or a slice.

        Returns:
            A single SentenceH5Data for an int index, or a list of
            SentenceH5Data for a slice.

        Raises:
            NotImplementedError: For any other index type.
        """
        if isinstance(idx, slice):
            # slice.indices normalizes None/negative bounds and clamps to the
            # dataset length. This fixes the previous off-by-one where an open
            # stop (e.g. corpus[:]) silently dropped the last sentence, and
            # where `or`-based defaults misread an explicit 0 as "unset".
            return [self[i] for i in range(*idx.indices(self.__len))]
        elif isinstance(idx, int):
            i = self.__len + idx if idx < 0 else idx
            key = to_idx(i)
            return SentenceH5Data(self.data[key])
        else:
            raise NotImplementedError

    def __repr__(self):
        return f"{self.__name}: containing {self.__len} items"

    def _init_vector_map(self):
        """Create main hashmap for all vectors to get their metadata.

        Returns:
            (refmap, n_vec): dict from global vector id to TokenH5Data, and
            the total number of token vectors across the corpus.

        TODO Initialization is a little slow... Should this be stored in a
        separate hdf5 file? This doesn't change. Check for special hdf5 file
        and see if it exists already. If it does, open it. If not, create it.
        """
        refmap = {}
        print("Initializing reference map for embedding vector...")
        n_vec = 0
        for sentence in self:
            for i in range(len(sentence)):
                refmap[n_vec] = TokenH5Data(sentence, i)
                n_vec += 1
        return refmap, n_vec

    def extract(self, layer):
        """Extract embeddings from a particular layer from the dataset.

        Args:
            layer: Layer index, passed through to each sentence's indexer.

        Returns:
            A single stacked np.ndarray over all examples.
        """
        # NOTE(review): assumes SentenceH5Data supports integer layer
        # indexing yielding a 2D array — confirm against its definition.
        embeddings = [sentence[layer] for sentence in self]
        return np.vstack(embeddings)

    def find(self, vec_num):
        """Find a vector's metadata (by id) in the hdf5 file. Needed to find sentence info and other attr"""
        return self.refmap[vec_num]

    def find2d(self, idxs):
        """Find metadata for a 2D nested list of vector ids (see `find`)."""
        return [[self.refmap[i] for i in row] for row in idxs]