# Spaces: Runtime error (appears to be hosting-page status text, not module code)
import pytest | |
import logging | |
import hypothesis.strategies as st | |
import chromadb.test.property.strategies as strategies | |
from chromadb.api import API | |
import chromadb.api.types as types | |
from hypothesis.stateful import ( | |
Bundle, | |
RuleBasedStateMachine, | |
rule, | |
initialize, | |
multiple, | |
consumes, | |
run_state_machine_as_test, | |
MultipleResults, | |
) | |
from typing import Dict, Optional | |
class CollectionStateMachine(RuleBasedStateMachine):
    """Stateful property test for the collection-level operations of the API.

    ``self.model`` maps the name of every collection that is expected to exist
    to the metadata it is expected to carry. Each rule performs one operation
    against ``self.api``, applies the same change to the model, and asserts
    that the two agree. The model, not the objects held in the ``collections``
    bundle, is the source of truth: the bundle can hold several objects that
    share a name, and only one of them is refreshed by any given rule.
    """

    # Collections handed from one rule to the next.
    collections: Bundle[strategies.Collection] = Bundle("collections")
    # Expected state: collection name -> expected metadata (may be None).
    model: Dict[str, Optional[types.CollectionMetadata]]

    # NOTE(review): the reviewed text had no decorators even though `rule`,
    # `initialize`, `consumes` and `st` are imported and the methods return
    # `multiple(...)`; without them the machine defines no rules. The names
    # `strategies.collections()`, `strategies.collection_metadata` and
    # `strategies.collection_name()` are assumed to be provided by
    # chromadb.test.property.strategies -- confirm against that module.

    def __init__(self, api: API):
        """Bind the machine to ``api`` and start with an empty model."""
        super().__init__()
        self.model = {}
        self.api = api

    @initialize()
    def initialize(self) -> None:
        """Reset the API and the model so both start out empty."""
        self.api.reset()
        self.model = {}

    @rule(target=collections, coll=strategies.collections())
    def create_coll(
        self, coll: strategies.Collection
    ) -> MultipleResults[strategies.Collection]:
        """Create ``coll``; creating a name that already exists must raise.

        Returns the collection for the bundle on success, nothing otherwise.
        """
        if coll.name in self.model:
            with pytest.raises(Exception):
                self.api.create_collection(
                    name=coll.name,
                    metadata=coll.metadata,
                    embedding_function=coll.embedding_function,
                )
            return multiple()

        created = self.api.create_collection(
            name=coll.name,
            metadata=coll.metadata,
            embedding_function=coll.embedding_function,
        )
        self.model[coll.name] = coll.metadata

        assert created.name == coll.name
        assert created.metadata == coll.metadata
        return multiple(coll)

    @rule(coll=collections)
    def get_coll(self, coll: strategies.Collection) -> None:
        """Fetch ``coll`` by name; fetching an unknown name must raise."""
        if coll.name in self.model:
            fetched = self.api.get_collection(name=coll.name)
            assert fetched.name == coll.name
            # Compare with the model: `coll` may be a stale bundle entry whose
            # name was later updated through a different bundle object.
            assert fetched.metadata == self.model[coll.name]
        else:
            with pytest.raises(Exception):
                self.api.get_collection(name=coll.name)

    @rule(coll=consumes(collections))
    def delete_coll(self, coll: strategies.Collection) -> None:
        """Delete ``coll``; deleting an unknown name must raise.

        In both cases the collection must not be retrievable afterwards.
        """
        if coll.name in self.model:
            self.api.delete_collection(name=coll.name)
            del self.model[coll.name]
        else:
            with pytest.raises(Exception):
                self.api.delete_collection(name=coll.name)

        with pytest.raises(Exception):
            self.api.get_collection(name=coll.name)

    @rule()
    def list_collections(self) -> None:
        """Check that the API lists exactly as many collections as the model
        holds and that every listed name is known to the model."""
        colls = self.api.list_collections()
        assert len(colls) == len(self.model)
        for c in colls:
            assert c.name in self.model

    @rule(
        target=collections,
        new_metadata=st.one_of(st.none(), strategies.collection_metadata),
        coll=st.one_of(consumes(collections), strategies.collections()),
    )
    def get_or_create_coll(
        self,
        coll: strategies.Collection,
        new_metadata: Optional[types.Metadata],
    ) -> MultipleResults[strategies.Collection]:
        """Exercise ``get_or_create_collection`` for new and existing names.

        Cases:
          0. ``new_metadata`` is None, ``coll`` exists: the existing collection
             is returned with its existing metadata (updating with None is a
             no-op).
          1. ``new_metadata`` is None, ``coll`` is new: a collection is created
             with metadata None.
          2. ``new_metadata`` is not None, ``coll`` exists: the existing
             collection is returned with the updated metadata.
          3. ``new_metadata`` is not None, ``coll`` is new: a collection is
             created with ``new_metadata``, ignoring the metadata of the
             input ``coll``.

        Ignoring the metadata of the generated collection is a bit odd, but it
        is the easiest way to exercise all four cases.
        """
        # Update the model first.
        if coll.name in self.model and new_metadata is None:
            # Case 0: the stored metadata survives.
            coll.metadata = self.model[coll.name]
        else:
            # Cases 1, 2 and 3: the metadata passed to the call wins.
            coll.metadata = new_metadata
        self.model[coll.name] = coll.metadata

        # Then the API.
        c = self.api.get_or_create_collection(
            name=coll.name,
            metadata=new_metadata,
            embedding_function=coll.embedding_function,
        )

        # Model and API must be in sync.
        assert c.name == coll.name
        assert c.metadata == coll.metadata
        return multiple(coll)

    @rule(
        target=collections,
        coll=consumes(collections),
        new_metadata=strategies.collection_metadata,
        new_name=st.one_of(st.none(), strategies.collection_name()),
    )
    def modify_coll(
        self,
        coll: strategies.Collection,
        new_metadata: Optional[types.Metadata],
        new_name: Optional[str],
    ) -> MultipleResults[strategies.Collection]:
        """Change the metadata and/or name of ``coll`` via ``modify``.

        Modifying an unknown collection, or renaming onto a name held by a
        different collection, must raise; in both cases nothing is returned
        to the bundle and the model is left untouched.
        """
        if coll.name not in self.model:
            with pytest.raises(Exception):
                self.api.get_collection(name=coll.name)
            return multiple()

        target = self.api.get_collection(name=coll.name)

        # Detect a rename collision before touching the model, so that a
        # modify() which is expected to fail leaves the expected state intact.
        if (
            new_name is not None
            and new_name != coll.name
            and new_name in self.model
        ):
            with pytest.raises(Exception):
                target.modify(metadata=new_metadata, name=new_name)
            return multiple()

        if new_metadata is not None:
            self.model[coll.name] = new_metadata
        # Refresh the bundle object from the model; it may have gone stale.
        coll.metadata = self.model[coll.name]

        if new_name is not None:
            # pop() runs before the assignment, so new_name == coll.name is safe.
            self.model[new_name] = self.model.pop(coll.name)
            coll.name = new_name

        target.modify(metadata=new_metadata, name=new_name)
        refreshed = self.api.get_collection(name=coll.name)

        assert refreshed.name == coll.name
        assert refreshed.metadata == coll.metadata
        return multiple(coll)
def test_collections(caplog: pytest.LogCaptureFixture, api: API) -> None:
    """Run the collection state machine against the ``api`` fixture.

    Log capture is limited to ERROR and above for the duration of the run.
    """
    caplog.set_level(logging.ERROR)

    def build_machine() -> CollectionStateMachine:
        # Each call builds a new machine bound to the same API fixture.
        return CollectionStateMachine(api)

    run_state_machine_as_test(build_machine)  # type: ignore