ir_chinese_medqa / colbert /indexing /index_manager.py
欧卫
'add_app_files'
58627fa
raw
history blame contribute delete
878 Bytes
import torch
import numpy as np
from bitarray import bitarray
class IndexManager:
    """Thin helper for writing index artifacts to disk.

    The constructor records ``dim``; none of the methods defined here read
    it back, so it is only kept available for callers.
    """

    def __init__(self, dim):
        self.dim = dim

    def save(self, tensor, path_prefix):
        """Write ``tensor`` to ``path_prefix`` by delegating to ``torch.save``."""
        torch.save(tensor, path_prefix)

    def save_bitarray(self, bitarray, path_prefix):
        """Write ``bitarray`` to ``path_prefix`` as raw binary.

        The file is opened in ``"wb"`` mode (truncating any existing file) and
        handed to ``bitarray.tofile``. Inside this method the ``bitarray``
        parameter shadows the module-level ``bitarray`` import.
        """
        with open(path_prefix, mode="wb") as sink:
            bitarray.tofile(sink)
def load_index_part(filename, verbose=True):
    """Load one index part from ``filename`` with ``torch.load``.

    Args:
        filename: Path (or file-like object) passed straight to ``torch.load``.
        verbose: Unused here; retained so existing call sites keep working.

    Returns:
        The loaded object. If the file holds a ``list``, its items are
        concatenated with ``torch.cat`` (along dimension 0, the default) and
        the resulting tensor is returned instead.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version / ``weights_only`` default -- only use on trusted files.
    part = torch.load(filename)

    # Backward compatibility: some files store a list of chunks rather than a
    # single tensor. isinstance (not ``type(...) ==``) also accepts subclasses.
    if isinstance(part, list):
        part = torch.cat(part)

    return part
def load_compressed_index_part(filename, dim, bits):
    """Load a packed binary index part as a 2-D ``uint8`` tensor.

    The file's bytes are read verbatim and reshaped to
    ``(n, ceil(dim * bits / 8))`` where ``n = total_bits // dim // bits`` and
    ``total_bits`` is eight times the file size in bytes.

    Args:
        filename: Path of the binary file to read.
        dim: Number of values per row.
            # presumably the embedding dimension -- confirm against the writer
        bits: Number of bits per value.

    Returns:
        A ``torch.uint8`` tensor of shape ``(n, ceil(dim * bits / 8))``.

    Raises:
        RuntimeError: From ``reshape`` if the file size does not match the
            computed shape (e.g. when ``dim * bits`` is not a multiple of 8).
        ZeroDivisionError: If ``dim`` or ``bits`` is zero.
    """
    # Read the bytes straight into a writable uint8 array; this avoids the
    # former bitarray -> bytes -> ndarray -> tensor chain of copies.
    with open(filename, "rb") as f:
        raw = np.fromfile(f, dtype=np.uint8)

    total_bits = raw.size * 8
    n = total_bits // dim // bits

    # Integer ceil-division: bytes needed to hold one row of dim * bits bits.
    row_bytes = (dim * bits + 7) // 8

    # from_numpy shares memory with ``raw`` (no copy); ``raw`` is a fresh,
    # writable array that nothing else references.
    part = torch.from_numpy(raw)
    part = part.reshape((n, row_bytes))

    return part