import os
import queue
import ujson
import threading

from contextlib import contextmanager

from colbert.indexing.codecs.residual import ResidualCodec
from colbert.utils.utils import print_message


class IndexSaver():
    def __init__(self, config):
        self.config = config

    def save_codec(self, codec):
        codec.save(index_path=self.config.index_path_)

    def load_codec(self):
        return ResidualCodec.load(index_path=self.config.index_path_)

    def try_load_codec(self):
        # Returns True iff a codec can be loaded from the index path.
        try:
            ResidualCodec.load(index_path=self.config.index_path_)
            return True
        except Exception:
            return False

    def check_chunk_exists(self, chunk_idx):
        # TODO: Verify that the chunk has the right amount of data?

        doclens_path = os.path.join(self.config.index_path_, f'doclens.{chunk_idx}.json')
        if not os.path.exists(doclens_path):
            return False

        metadata_path = os.path.join(self.config.index_path_, f'{chunk_idx}.metadata.json')
        if not os.path.exists(metadata_path):
            return False

        path_prefix = os.path.join(self.config.index_path_, str(chunk_idx))

        codes_path = f'{path_prefix}.codes.pt'
        if not os.path.exists(codes_path):
            return False

        residuals_path = f'{path_prefix}.residuals.pt'  # f'{path_prefix}.residuals.bn'
        if not os.path.exists(residuals_path):
            return False

        return True

    @contextmanager
    def thread(self):
        self.codec = self.load_codec()

        # Bounded queue applies backpressure: compression stalls if disk writes fall behind.
        self.saver_queue = queue.Queue(maxsize=3)
        thread = threading.Thread(target=self._saver_thread)
        thread.start()

        try:
            yield
        finally:
            self.saver_queue.put(None)  # sentinel: signals the saver thread to exit
            thread.join()

            del self.saver_queue
            del self.codec

    def save_chunk(self, chunk_idx, offset, embs, doclens):
        compressed_embs = self.codec.compress(embs)

        self.saver_queue.put((chunk_idx, offset, compressed_embs, doclens))

    def _saver_thread(self):
        # Consume queued chunks until the None sentinel arrives.
        for args in iter(self.saver_queue.get, None):
            self._write_chunk_to_disk(*args)

    def _write_chunk_to_disk(self, chunk_idx, offset, compressed_embs, doclens):
        path_prefix = os.path.join(self.config.index_path_, str(chunk_idx))
        compressed_embs.save(path_prefix)

        doclens_path = os.path.join(self.config.index_path_, f'doclens.{chunk_idx}.json')
        with open(doclens_path, 'w') as output_doclens:
            ujson.dump(doclens, output_doclens)

        metadata_path = os.path.join(self.config.index_path_, f'{chunk_idx}.metadata.json')
        with open(metadata_path, 'w') as output_metadata:
            metadata = {'passage_offset': offset, 'num_passages': len(doclens), 'num_embeddings': len(compressed_embs)}
            ujson.dump(metadata, output_metadata)
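
# ---------------------------------------------------------------------------
# Usage sketch: a minimal illustration of how IndexSaver might be driven
# during indexing. `config`, `codec`, and the `chunks` iterable are
# hypothetical stand-ins; `config` is only assumed to expose `index_path_`,
# as the methods above require.
#
#   saver = IndexSaver(config)
#   saver.save_codec(codec)  # persist the ResidualCodec once per index
#
#   with saver.thread():  # loads the codec and starts the background writer
#       offset = 0
#       for chunk_idx, (embs, doclens) in enumerate(chunks):
#           saver.save_chunk(chunk_idx, offset, embs, doclens)  # enqueues; disk I/O happens off-thread
#           offset += len(doclens)
#   # On exit, the None sentinel is enqueued and the writer thread is joined,
#   # so every queued chunk is on disk once the `with` block ends.
# ---------------------------------------------------------------------------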