SungBeom's picture
Upload folder using huggingface_hub
4a51346
raw
history blame
6.21 kB
import pytest
import logging
import hypothesis.strategies as st
import chromadb.test.property.strategies as strategies
from chromadb.api import API
import chromadb.api.types as types
from hypothesis.stateful import (
Bundle,
RuleBasedStateMachine,
rule,
initialize,
multiple,
consumes,
run_state_machine_as_test,
MultipleResults,
)
from typing import Dict, Optional
class CollectionStateMachine(RuleBasedStateMachine):
    """Stateful property test for collection create/get/delete/list/modify.

    Hypothesis invokes the rules below in arbitrary order.  Each rule performs
    an operation through ``self.api`` and mirrors its expected effect in
    ``self.model``; assertions then check that the two agree.

    ``self.model`` is the source of truth for expected metadata.  Objects held
    in the ``collections`` bundle are only snapshots taken when a rule returned
    them: several bundle entries may share a name (``get_or_create_coll`` can
    add a second entry for an existing name), so an entry's ``metadata`` may be
    out of date and must not be used as the expected value.
    """

    # Collections produced by rules, available as inputs to later rules.
    collections: Bundle[strategies.Collection] = Bundle("collections")
    # Expected backend state: collection name -> metadata (``None`` if unset).
    model: Dict[str, Optional[types.CollectionMetadata]]

    def __init__(self, api: API):
        """Store the API under test and start with an empty model."""
        super().__init__()
        self.model = {}
        self.api = api

    @initialize()
    def initialize(self) -> None:
        """Call ``api.reset()`` and clear the model before any rule runs."""
        self.api.reset()
        self.model = {}

    @rule(target=collections, coll=strategies.collections())
    def create_coll(
        self, coll: strategies.Collection
    ) -> MultipleResults[strategies.Collection]:
        """Create ``coll``; creating a name that is already modelled must raise.

        Returns the collection into the bundle only when it was created.
        """
        if coll.name in self.model:
            with pytest.raises(Exception):
                self.api.create_collection(
                    name=coll.name,
                    metadata=coll.metadata,
                    embedding_function=coll.embedding_function,
                )
            return multiple()

        c = self.api.create_collection(
            name=coll.name,
            metadata=coll.metadata,
            embedding_function=coll.embedding_function,
        )
        self.model[coll.name] = coll.metadata

        assert c.name == coll.name
        assert c.metadata == self.model[coll.name]
        return multiple(coll)

    @rule(coll=collections)
    def get_coll(self, coll: strategies.Collection) -> None:
        """Fetch ``coll`` by name; fetching an unmodelled name must raise."""
        if coll.name in self.model:
            c = self.api.get_collection(name=coll.name)
            assert c.name == coll.name
            # Compare against the model, not the (possibly stale) bundle entry.
            assert c.metadata == self.model[coll.name]
        else:
            with pytest.raises(Exception):
                self.api.get_collection(name=coll.name)

    @rule(coll=consumes(collections))
    def delete_coll(self, coll: strategies.Collection) -> None:
        """Delete ``coll``; deleting an unmodelled name must raise.

        In both cases a subsequent ``get_collection`` must raise.
        """
        if coll.name in self.model:
            self.api.delete_collection(name=coll.name)
            del self.model[coll.name]
        else:
            with pytest.raises(Exception):
                self.api.delete_collection(name=coll.name)

        with pytest.raises(Exception):
            self.api.get_collection(name=coll.name)

    @rule()
    def list_collections(self) -> None:
        """Check that the listed collections match the model in count and names."""
        colls = self.api.list_collections()
        assert len(colls) == len(self.model)
        for c in colls:
            assert c.name in self.model

    @rule(
        target=collections,
        new_metadata=st.one_of(st.none(), strategies.collection_metadata),
        coll=st.one_of(consumes(collections), strategies.collections()),
    )
    def get_or_create_coll(
        self,
        coll: strategies.Collection,
        new_metadata: Optional[types.Metadata],
    ) -> MultipleResults[strategies.Collection]:
        """Exercise ``get_or_create_collection`` for new and existing names.

        Cases:
          0. ``new_metadata`` is None, ``coll`` exists:
             the existing collection is returned with its existing metadata
             (an update with None is a no-op).
          1. ``new_metadata`` is None, ``coll`` is new:
             a new collection is created with metadata None.
          2. ``new_metadata`` is not None, ``coll`` exists:
             the existing collection is returned with updated metadata.
          3. ``new_metadata`` is not None, ``coll`` is new:
             a new collection is created with ``new_metadata``, ignoring the
             metadata of the generated ``coll``.

        Ignoring the metadata of the generated collection is a bit odd, but it
        is the easiest way to exercise all four cases.
        """
        # Update the model first.
        if coll.name not in self.model:
            # Cases 1 and 3.
            coll.metadata = new_metadata
        else:
            # Cases 0 and 2.
            coll.metadata = (
                self.model[coll.name] if new_metadata is None else new_metadata
            )
        self.model[coll.name] = coll.metadata

        # Then apply the same operation to the API.
        c = self.api.get_or_create_collection(
            name=coll.name,
            metadata=new_metadata,
            embedding_function=coll.embedding_function,
        )

        # Model and API must agree.
        assert c.name == coll.name
        assert c.metadata == self.model[coll.name]
        return multiple(coll)

    @rule(
        target=collections,
        coll=consumes(collections),
        new_metadata=strategies.collection_metadata,
        new_name=st.one_of(st.none(), strategies.collection_name()),
    )
    def modify_coll(
        self,
        coll: strategies.Collection,
        new_metadata: types.Metadata,
        new_name: Optional[str],
    ) -> MultipleResults[strategies.Collection]:
        """Rename and/or re-tag ``coll`` through ``Collection.modify``.

        A ``None`` value for ``new_metadata`` or ``new_name`` means "leave that
        attribute unchanged".  Renaming onto another existing collection must
        raise.  Returns the updated collection into the bundle on success and
        nothing otherwise.
        """
        if coll.name not in self.model:
            with pytest.raises(Exception):
                self.api.get_collection(name=coll.name)
            return multiple()

        c = self.api.get_collection(name=coll.name)

        # Check for a name collision BEFORE touching the model, so a rejected
        # modify cannot leave the model out of sync with the backend.
        # NOTE(review): assumes a rejected modify leaves the stored collection
        # unchanged -- confirm against the API implementation.
        if (
            new_name is not None
            and new_name != coll.name
            and new_name in self.model
        ):
            with pytest.raises(Exception):
                c.modify(metadata=new_metadata, name=new_name)
            return multiple()

        # Expected metadata after the call; taken from the model rather than
        # from the bundle snapshot when no new metadata is supplied.
        expected_metadata = (
            self.model[coll.name] if new_metadata is None else new_metadata
        )
        if new_name is not None:
            del self.model[coll.name]
            coll.name = new_name
        coll.metadata = expected_metadata
        self.model[coll.name] = expected_metadata

        c.modify(metadata=new_metadata, name=new_name)
        c = self.api.get_collection(name=coll.name)

        assert c.name == coll.name
        assert c.metadata == self.model[coll.name]
        return multiple(coll)
def test_collections(caplog: pytest.LogCaptureFixture, api: API) -> None:
    """Run ``CollectionStateMachine`` against the ``api`` fixture."""
    # Restrict captured log records to ERROR and above for this test.
    caplog.set_level(logging.ERROR)

    def build_machine() -> CollectionStateMachine:
        # Hypothesis calls this factory to obtain a fresh machine per run.
        return CollectionStateMachine(api)

    run_state_machine_as_test(build_machine)  # type: ignore