Spaces:
Runtime error
Runtime error
File size: 878 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import torch
import numpy as np
from bitarray import bitarray
class IndexManager:
    """Write index artifacts (tensors and bit arrays) to disk.

    ``dim`` is only recorded on the instance (presumably the vector
    dimension of the index — confirm with callers); the save helpers here
    do not read it.
    """

    def __init__(self, dim):
        # Kept for callers that read ``manager.dim``.
        self.dim = dim

    def save(self, tensor, path_prefix):
        """Serialize ``tensor`` to the file at ``path_prefix`` via ``torch.save``."""
        target = path_prefix
        torch.save(tensor, target)

    def save_bitarray(self, bitarray, path_prefix):
        """Write ``bitarray`` to ``path_prefix`` by calling its ``tofile`` method.

        The destination is opened in binary write mode, so an existing file
        is truncated.
        """
        # NOTE: inside this method ``bitarray`` is the argument, which shadows
        # the module-level ``bitarray`` import; the name is kept because
        # callers may pass it by keyword.
        sink = open(path_prefix, "wb")
        with sink:
            bitarray.tofile(sink)
def load_index_part(filename, verbose=True):
    """Load one saved index part from ``filename``.

    Args:
        filename: Path handed directly to ``torch.load``.
        verbose: Currently unused; kept so existing callers that pass it
            keep working.

    Returns:
        The loaded object. If the file holds a list (an older on-disk
        layout), its elements are concatenated along dim 0 with
        ``torch.cat`` and the resulting tensor is returned.
    """
    # NOTE: torch.load unpickles its input; only point it at trusted files.
    part = torch.load(filename)

    # Backward compatibility: older files stored a list of tensor chunks
    # instead of one tensor, so stitch them back together.
    if isinstance(part, list):
        part = torch.cat(part)

    return part
def load_compressed_index_part(filename, dim, bits):
    """Load a bit-packed index part from disk as a 2-D ``uint8`` tensor.

    The file is read as a flat run of raw bytes (presumably one produced by
    ``IndexManager.save_bitarray`` — confirm against the writer). The bytes
    are then split into rows of ``ceil(dim * bits / 8)`` bytes each, i.e.
    room for ``dim`` values of ``bits`` bits, rounded up to a whole byte.

    Args:
        filename: Path of the packed file.
        dim: Number of packed values per row.
        bits: Number of bits per packed value.

    Returns:
        A ``torch.uint8`` tensor of shape ``(n, ceil(dim * bits / 8))``,
        where ``n`` is the number of whole rows in the file.

    Raises:
        ZeroDivisionError: If ``dim * bits`` is 0.
        RuntimeError: If the file size is not a whole number of rows
            (raised by ``reshape``).
    """
    # Bytes per row. The integer ceiling avoids the float round-trip of np.ceil.
    row_bytes = (dim * bits + 7) // 8

    # Read the raw bytes straight into a writable array. These are exactly the
    # bytes a bitarray fromfile()/tobytes() round-trip yields, but without the
    # intermediate copies. torch.from_numpy then shares that memory rather
    # than copying it again.
    raw = np.fromfile(filename, dtype=np.uint8)
    part = torch.from_numpy(raw)

    # Derive the row count from the same byte width the reshape uses.
    # Computing it as total_bits // dim // bits only agrees with that width
    # when dim * bits is a multiple of 8; otherwise the reshape would fail.
    n = raw.size // row_bytes
    return part.reshape((n, row_bytes))
|