Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import tqdm | |
from colbert.indexing.loaders import load_doclens | |
from colbert.utils.utils import print_message, flatten | |
def optimize_ivf(orig_ivf, orig_ivf_lengths, index_path):
    """Convert an embedding-level IVF into a pid-level IVF and save it.

    Every entry of ``orig_ivf`` is looked up in an embedding-to-pid table
    built from the index's document lengths; the entries belonging to each
    centroid are then collapsed to their unique pids.

    Args:
        orig_ivf: Embedding positions used to index the embedding-to-pid
            table, laid out as consecutive per-centroid runs (presumably a
            1-D integer tensor -- confirm against the caller).
        orig_ivf_lengths: Length of each centroid's run inside ``orig_ivf``;
            must provide ``.tolist()``.
        index_path: Path handed to ``load_doclens`` and the directory in
            which ``ivf.pid.pt`` is written.

    Returns:
        A tuple ``(ivf, ivf_lengths)``: the concatenation of the unique pids
        of every centroid, and a tensor holding how many unique pids each
        centroid contributed.
    """
    print_message("#> Optimizing IVF to store map from centroids to list of pids..")
    print_message("#> Building the emb2pid mapping..")

    # assert self.num_embeddings == sum(flatten(all_doclens))
    doclens = flatten(load_doclens(index_path, flatten=False))
    emb2pid = torch.zeros(sum(doclens), dtype=torch.int)

    # EVENTUALLY: Use two tensors. emb2pid_offsets will have every 256th
    # element. emb2pid_delta will have the delta from the corresponding
    # offset.

    # A pid is the document's position in the flattened doclens list; each
    # document owns a contiguous run of `doclen` embedding slots.
    first_emb = 0
    for pid, doclen in enumerate(doclens):
        next_emb = first_emb + doclen
        emb2pid[first_emb:next_emb] = pid
        first_emb = next_emb

    print_message("len(emb2pid) =", len(emb2pid))

    # Translate every IVF entry from an embedding position to its pid.
    pid_per_entry = emb2pid[orig_ivf]

    # Walk the per-centroid runs and keep each centroid's distinct pids.
    unique_pids_per_centroid = []
    start = 0
    for run_length in tqdm.tqdm(orig_ivf_lengths.tolist()):
        end = start + run_length
        unique_pids_per_centroid.append(torch.unique(pid_per_entry[start:end]))
        start = end

    ivf = torch.cat(unique_pids_per_centroid)
    ivf_lengths = torch.tensor(
        [centroid_pids.shape[0] for centroid_pids in unique_pids_per_centroid]
    )

    optimized_ivf_path = os.path.join(index_path, 'ivf.pid.pt')
    torch.save((ivf, ivf_lengths), optimized_ivf_path)
    print_message(f"#> Saved optimized IVF to {optimized_ivf_path}")

    # The embedding-level file is only reported, never deleted here.
    original_ivf_path = os.path.join(index_path, 'ivf.pt')
    if os.path.exists(original_ivf_path):
        print_message(f"#> Original IVF at path \"{original_ivf_path}\" can now be removed")

    return ivf, ivf_lengths