import uuid
import hashlib
import warnings
import itertools as it
import functools as ft
from pathlib import Path

from openai import BadRequestError

# from ._logging import Logger


class FileObject:
    """An open file handle that knows the checksum of its contents."""

    _window = 20

    def __init__(self, path):
        self.fp = path.open('rb')
        self.chunk = 2 ** self._window

    def close(self):
        self.fp.close()

    @ft.cached_property
    def checksum(self):
        # Hash the file in fixed-size chunks, then rewind so the handle
        # can still be read (and uploaded) afterwards.
        csum = hashlib.blake2b()
        while True:
            data = self.fp.read(self.chunk)
            if not data:
                break
            csum.update(data)
        self.fp.seek(0)

        return csum.hexdigest()

class FileStream:
    def __init__(self, paths):
        self.paths = paths
        self.streams = []

    def __len__(self):
        return len(self.streams)

    def __iter__(self):
        for p in self.paths:
            stream = FileObject(p)
            self.streams.append(stream)
            yield stream

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for s in self.streams:
            s.close()
        self.streams.clear()

class VectorStoreManager:
    def __init__(self, client, vector_store_id):
        self.client = client
        self.vector_store_id = vector_store_id

    def __iter__(self):
        # Page through the vector store's file listing, yielding file IDs.
        kwargs = {}
        while True:
            vs_files = self.client.beta.vector_stores.files.list(
                vector_store_id=self.vector_store_id,
                **kwargs,
            )
            for f in vs_files.data:
                yield f.id
            if not vs_files.has_more:
                break
            kwargs['after'] = vs_files.last_id

    def cleanup(self):
        # Delete each file before removing the vector store itself.
        files = list(self)
        for f in files:
            self.client.files.delete(f)
        self.client.beta.vector_stores.delete(self.vector_store_id)

class FileManager(VectorStoreManager):
    """Batch uploader that creates a vector store on demand and skips
    files whose contents (by checksum) have already been uploaded."""

    def __init__(self, client, prefix, batch_size=20):
        super().__init__(client, None)
        self.prefix = prefix
        self.batch_size = batch_size
        self.storage = set()

    def __bool__(self):
        return bool(self.storage)

    def __call__(self, paths):
        files = []

        self.test_and_setup()
        for batch in self.each(paths):
            with FileStream(batch) as stream:
                for s in stream:
                    if s.checksum not in self.storage:
                        files.append(s)
                if files:
                    self.put(files)
                    files.clear()

        return '\n'.join(self.ls())

    def ls(self):
        for i in self:
            entry = self.client.files.retrieve(i)
            yield entry.filename

    def cleanup(self):
        if self.storage:
            assert self.vector_store_id is not None
            super().cleanup()
            self.storage.clear()
            self.vector_store_id = None

    def test_and_setup(self):
        # Create a fresh vector store unless one is already in use.
        if self:
            msg = f'Vector store already exists ({self.vector_store_id})'
            warnings.warn(msg)
        else:
            name = f'{self.prefix}{uuid.uuid4()}'
            vector_store = self.client.beta.vector_stores.create(
                name=name,
            )
            self.vector_store_id = vector_store.id

    def each(self, paths):
        # Yield the paths in batches of at most batch_size.
        left = 0
        while left < len(paths):
            right = left + self.batch_size
            yield list(map(Path, it.islice(paths, left, right)))
            left = right

    def put(self, fobjs):
        # Upload a batch of open file handles; record their checksums only
        # once the whole batch has completed successfully.
        files = [x.fp for x in fobjs]
        try:
            bat = self.client.beta.vector_stores.file_batches.upload_and_poll(
                vector_store_id=self.vector_store_id,
                files=files,
            )
        except BadRequestError as err:
            raise InterruptedError(err.body['message']) from err
        if bat.file_counts.completed != len(files):
            err = f'Error uploading documents: {bat.file_counts}'
            raise InterruptedError(err)

        self.storage.update(x.checksum for x in fobjs)
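
# Example usage (a minimal sketch): the OpenAI client setup, the 'demo-'
# prefix, and the docs/ directory are illustrative assumptions.
#
#   from openai import OpenAI
#
#   client = OpenAI()
#   manager = FileManager(client, prefix='demo-')
#   try:
#       print(manager(list(Path('docs').iterdir())))
#   finally:
#       manager.cleanup()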