Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
from bitarray import bitarray | |
class IndexManager:
    """Write index artifacts (tensors and bit arrays) to disk.

    The ``dim`` given at construction is kept on the instance; none of the
    methods below consult it.
    """

    def __init__(self, dim):
        # Kept for callers that inspect it; unused by this class's methods.
        self.dim = dim

    def save(self, tensor, path_prefix):
        """Serialize ``tensor`` to the file ``path_prefix`` with ``torch.save``."""
        target_path = path_prefix
        torch.save(tensor, target_path)

    def save_bitarray(self, bitarray, path_prefix):
        """Dump ``bitarray`` to the file ``path_prefix`` (opened in ``"wb"`` mode).

        The object's own ``tofile`` method does the writing. Note that the
        parameter name shadows the module-level ``bitarray`` import inside
        this method.
        """
        out_file = open(path_prefix, "wb")
        try:
            bitarray.tofile(out_file)
        finally:
            # Close even if tofile raises, as the original `with` block did.
            out_file.close()
def load_index_part(filename, verbose=True):
    """Load one saved index part from ``filename`` as a single tensor.

    Args:
        filename: Path handed to ``torch.load``.
        verbose: Accepted for interface compatibility; currently unused.

    Returns:
        The loaded object. If it is a list (kept for backward compatibility,
        presumably an older on-disk format), its elements are concatenated
        with ``torch.cat``.
    """
    # NOTE(review): torch.load is pickle-based; only load index parts that
    # come from a trusted source.
    part = torch.load(filename)

    # isinstance rather than `type(...) == list` so list subclasses are
    # handled the same way.
    if isinstance(part, list):  # for backward compatibility
        part = torch.cat(part)

    return part
def load_compressed_index_part(filename, dim, bits):
    """Load a packed index part from ``filename`` as a 2-D ``uint8`` tensor.

    Each row occupies ``ceil(dim * bits / 8)`` bytes. The number of rows is
    derived from the file's byte length.

    Args:
        filename: Path to the raw binary file.
        dim: Number of packed values per row.
        bits: Number of bits per packed value.

    Returns:
        A ``torch.uint8`` tensor of shape ``(n, ceil(dim * bits / 8))``.
        If the file length is not a whole number of rows, the final
        ``reshape`` raises, as before.
    """
    with open(filename, "rb") as f:
        # Loading into a bitarray and calling .tobytes() only reproduces the
        # file's bytes, so read them directly.
        raw = f.read()

    # Integer ceiling; avoids the float round-trip of np.ceil.
    row_bytes = (dim * bits + 7) // 8
    # Count rows in bytes so the count agrees with the per-row byte width used
    # below. Counting in bits (len_bits // dim // bits) overcounts when
    # dim * bits is not a multiple of 8, which makes the reshape fail.
    n = len(raw) // row_bytes

    # np.frombuffer over immutable bytes returns a read-only array.
    # torch.tensor copies it, so the result is an independent, writable tensor
    # rather than a view of that buffer.
    part = torch.tensor(np.frombuffer(raw, dtype=np.uint8))
    part = part.reshape((n, row_bytes))

    return part