import os
import torch
import tqdm
from colbert.indexing.loaders import load_doclens
from colbert.utils.utils import print_message, flatten


def optimize_ivf(orig_ivf, orig_ivf_lengths, index_path):
    print_message("#> Optimizing IVF to store map from centroids to list of pids..")

    print_message("#> Building the emb2pid mapping..")
    all_doclens = load_doclens(index_path, flatten=False)

    # assert self.num_embeddings == sum(flatten(all_doclens))

    all_doclens = flatten(all_doclens)
    total_num_embeddings = sum(all_doclens)

    # emb2pid[i] is the pid of the passage that embedding i belongs to.
    emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)

    """
    EVENTUALLY: Use two tensors. emb2pid_offsets will have every 256th element.
                emb2pid_delta will have the delta from the corresponding offset.
    """

    offset_doclens = 0
    for pid, dlength in enumerate(all_doclens):
        emb2pid[offset_doclens: offset_doclens + dlength] = pid
        offset_doclens += dlength
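    # An equivalent vectorized construction (an alternative sketch, not what runs above):
    #   emb2pid = torch.arange(len(all_doclens), dtype=torch.int).repeat_interleave(torch.tensor(all_doclens))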

    print_message("len(emb2pid) =", len(emb2pid))

    # Translate the IVF from embedding ids to pids, then deduplicate the pids
    # assigned to each centroid.
    ivf = emb2pid[orig_ivf]

    unique_pids_per_centroid = []
    ivf_lengths = []
    offset = 0
    for length in tqdm.tqdm(orig_ivf_lengths.tolist()):
        pids = torch.unique(ivf[offset:offset+length])
        unique_pids_per_centroid.append(pids)
        ivf_lengths.append(pids.shape[0])
        offset += length
    ivf = torch.cat(unique_pids_per_centroid)
    ivf_lengths = torch.tensor(ivf_lengths)
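    # The result is a CSR-style layout: `ivf` concatenates the unique pids for each
    # centroid, and `ivf_lengths[c]` says how many of those entries belong to centroid c.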

    original_ivf_path = os.path.join(index_path, 'ivf.pt')
    optimized_ivf_path = os.path.join(index_path, 'ivf.pid.pt')
    torch.save((ivf, ivf_lengths), optimized_ivf_path)
    print_message(f"#> Saved optimized IVF to {optimized_ivf_path}")

    if os.path.exists(original_ivf_path):
        print_message(f"#> Original IVF at path \"{original_ivf_path}\" can now be removed")

    return ivf, ivf_lengths
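

# Example usage (a minimal sketch; assumes `index_path` is a ColBERT index directory
# whose 'ivf.pt' stores an (ivf, ivf_lengths) tuple; the names below are illustrative):
#
#   orig_ivf, orig_ivf_lengths = torch.load(os.path.join(index_path, 'ivf.pt'), map_location='cpu')
#   ivf, ivf_lengths = optimize_ivf(orig_ivf, orig_ivf_lengths, index_path)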