Spaces:
Runtime error
Runtime error
import logging | |
import multiprocessing | |
from multiprocessing.connection import Connection | |
from typing import Generator, Callable | |
from hypothesis import given | |
import hypothesis.strategies as st | |
import pytest | |
import chromadb | |
from chromadb.api import API | |
from chromadb.config import Settings | |
import chromadb.test.property.strategies as strategies | |
import chromadb.test.property.invariants as invariants | |
from chromadb.test.property.test_embeddings import ( | |
EmbeddingStateMachine, | |
EmbeddingStateMachineStates, | |
) | |
from hypothesis.stateful import run_state_machine_as_test, rule, precondition | |
import os | |
import shutil | |
import tempfile | |
CreatePersistAPI = Callable[[], API] | |
configurations = [ | |
Settings( | |
chroma_api_impl="local", | |
chroma_db_impl="duckdb+parquet", | |
persist_directory=tempfile.gettempdir() + "/tests", | |
) | |
] | |
def settings(request: pytest.FixtureRequest) -> Generator[Settings, None, None]: | |
configuration = request.param | |
yield configuration | |
save_path = configuration.persist_directory | |
# Remove if it exists | |
if os.path.exists(save_path): | |
shutil.rmtree(save_path) | |
collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") | |
def test_persist( | |
settings: Settings, | |
collection_strategy: strategies.Collection, | |
embeddings_strategy: strategies.RecordSet, | |
) -> None: | |
api_1 = chromadb.Client(settings) | |
api_1.reset() | |
coll = api_1.create_collection( | |
name=collection_strategy.name, | |
metadata=collection_strategy.metadata, | |
embedding_function=collection_strategy.embedding_function, | |
) | |
coll.add(**embeddings_strategy) | |
invariants.count(coll, embeddings_strategy) | |
invariants.metadatas_match(coll, embeddings_strategy) | |
invariants.documents_match(coll, embeddings_strategy) | |
invariants.ids_match(coll, embeddings_strategy) | |
invariants.ann_accuracy( | |
coll, | |
embeddings_strategy, | |
embedding_function=collection_strategy.embedding_function, | |
) | |
api_1.persist() | |
del api_1 | |
api_2 = chromadb.Client(settings) | |
coll = api_2.get_collection( | |
name=collection_strategy.name, | |
embedding_function=collection_strategy.embedding_function, | |
) | |
invariants.count(coll, embeddings_strategy) | |
invariants.metadatas_match(coll, embeddings_strategy) | |
invariants.documents_match(coll, embeddings_strategy) | |
invariants.ids_match(coll, embeddings_strategy) | |
invariants.ann_accuracy( | |
coll, | |
embeddings_strategy, | |
embedding_function=collection_strategy.embedding_function, | |
) | |
def load_and_check( | |
settings: Settings, | |
collection_name: str, | |
record_set: strategies.RecordSet, | |
conn: Connection, | |
) -> None: | |
try: | |
api = chromadb.Client(settings) | |
coll = api.get_collection( | |
name=collection_name, | |
embedding_function=strategies.not_implemented_embedding_function(), | |
) | |
invariants.count(coll, record_set) | |
invariants.metadatas_match(coll, record_set) | |
invariants.documents_match(coll, record_set) | |
invariants.ids_match(coll, record_set) | |
invariants.ann_accuracy(coll, record_set) | |
except Exception as e: | |
conn.send(e) | |
raise e | |
class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates): | |
persist = "persist" | |
class PersistEmbeddingsStateMachine(EmbeddingStateMachine): | |
def __init__(self, api: API, settings: Settings): | |
self.api = api | |
self.settings = settings | |
self.last_persist_delay = 10 | |
self.api.reset() | |
super().__init__(self.api) | |
def persist(self) -> None: | |
self.on_state_change(PersistEmbeddingsStateMachineStates.persist) | |
self.api.persist() | |
collection_name = self.collection.name | |
# Create a new process and then inside the process run the invariants | |
# TODO: Once we switch off of duckdb and onto sqlite we can remove this | |
ctx = multiprocessing.get_context("spawn") | |
conn1, conn2 = multiprocessing.Pipe() | |
p = ctx.Process( | |
target=load_and_check, | |
args=(self.settings, collection_name, self.record_set_state, conn2), | |
) | |
p.start() | |
p.join() | |
if conn1.poll(): | |
e = conn1.recv() | |
raise e | |
def on_state_change(self, new_state: str) -> None: | |
if new_state == PersistEmbeddingsStateMachineStates.persist: | |
self.last_persist_delay = 10 | |
else: | |
self.last_persist_delay -= 1 | |
def test_persist_embeddings_state( | |
caplog: pytest.LogCaptureFixture, settings: Settings | |
) -> None: | |
caplog.set_level(logging.ERROR) | |
api = chromadb.Client(settings) | |
run_state_machine_as_test( | |
lambda: PersistEmbeddingsStateMachine(settings=settings, api=api) | |
) # type: ignore | |