import pytest
from typing import Generator, List, Callable, Dict, Union
from chromadb.types import Collection, Segment, SegmentScope
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.config import System, Settings
from chromadb.db.system import SysDB
from chromadb.db.base import NotFoundError, UniqueConstraintError
from pytest import FixtureRequest
import uuid


def sqlite() -> Generator[SysDB, None, None]:
    """Fixture generator for sqlite DB"""
    db = SqliteDB(System(Settings(sqlite_database=":memory:", allow_reset=True)))
    db.start()
    try:
        yield db
    finally:
        # Stop the DB even when the generator is closed before being exhausted.
        db.stop()


def db_fixtures() -> List[Callable[[], Generator[SysDB, None, None]]]:
    """Return the generator functions used to parametrize the `sysdb` fixture."""
    return [sqlite]


@pytest.fixture(scope="module", params=db_fixtures())
def sysdb(request: FixtureRequest) -> Generator[SysDB, None, None]:
    """Yield a started SysDB for each backend listed by `db_fixtures`.

    Delegating with `yield from` keeps the backend generator alive until
    pytest resumes this fixture at teardown, so the code after the backend's
    own `yield` (stopping the DB) actually runs. Pulling a single value with
    `next()` would abandon the generator and skip that cleanup.
    """
    yield from request.param()


sample_collections = [
    Collection(
        id=uuid.uuid4(),
        name="test_collection_1",
        topic="test_topic_1",
        metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3},
    ),
    Collection(
        id=uuid.uuid4(),
        name="test_collection_2",
        topic="test_topic_2",
        metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3},
    ),
    Collection(
        id=uuid.uuid4(),
        name="test_collection_3",
        topic="test_topic_3",
        metadata={"test_str": "str3", "test_int": 3, "test_float": 3.3},
    ),
]


def test_create_get_delete_collections(sysdb: SysDB) -> None:
    """Create the sample collections, query them by each filter, then delete one."""
    sysdb.reset()

    for collection in sample_collections:
        sysdb.create_collection(collection)

    # Sort once by name so the comparison does not depend on the order in
    # which the DB returns rows; sample_collections is already in name order.
    results = sorted(sysdb.get_collections(), key=lambda c: c["name"])
    assert results == sample_collections

    # Duplicate create fails
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(sample_collections[0])

    # Find by name
    for collection in sample_collections:
        result = sysdb.get_collections(name=collection["name"])
        assert result == [collection]

    # Find by topic
    for collection in sample_collections:
        result = sysdb.get_collections(topic=collection["topic"])
        assert result == [collection]

    # Find by id
    for collection in sample_collections:
        result = sysdb.get_collections(id=collection["id"])
        assert result == [collection]

    # Find by id and topic (positive case)
    for collection in sample_collections:
        result = sysdb.get_collections(id=collection["id"], topic=collection["topic"])
        assert result == [collection]

    # find by id and topic (negative case)
    for collection in sample_collections:
        result = sysdb.get_collections(id=collection["id"], topic="other_topic")
        assert result == []

    # Delete
    c1 = sample_collections[0]
    sysdb.delete_collection(c1["id"])

    results = sysdb.get_collections()
    assert c1 not in results
    assert len(results) == len(sample_collections) - 1
    assert sorted(results, key=lambda c: c["name"]) == sample_collections[1:]

    by_id_result = sysdb.get_collections(id=c1["id"])
    assert by_id_result == []

    # Duplicate delete throws an exception
    with pytest.raises(NotFoundError):
        sysdb.delete_collection(c1["id"])


def test_update_collections(sysdb: SysDB) -> None:
    """Update a collection's name, topic and metadata and read each change back.

    The expected value is built by mutating `coll` (and the `metadata` dict it
    was constructed with) alongside every update call; the assertions rely on
    those in-place mutations being visible through `coll`.
    """
    metadata: Dict[str, Union[str, int, float]] = {
        "test_str": "str1",
        "test_int": 1,
        "test_float": 1.3,
    }
    coll = Collection(
        id=uuid.uuid4(),
        name="test_collection_1",
        topic="test_topic_1",
        metadata=metadata,
    )

    sysdb.reset()

    sysdb.create_collection(coll)

    # Update name
    coll["name"] = "new_name"
    sysdb.update_collection(coll["id"], name=coll["name"])
    result = sysdb.get_collections(name=coll["name"])
    assert result == [coll]

    # Update topic
    coll["topic"] = "new_topic"
    sysdb.update_collection(coll["id"], topic=coll["topic"])
    result = sysdb.get_collections(topic=coll["topic"])
    assert result == [coll]

    # Add a new metadata key
    metadata["test_str2"] = "str2"
    sysdb.update_collection(coll["id"], metadata={"test_str2": "str2"})
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]

    # Update a metadata key
    metadata["test_str"] = "str3"
    sysdb.update_collection(coll["id"], metadata={"test_str": "str3"})
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]

    # Delete a metadata key
    del metadata["test_str"]
    sysdb.update_collection(coll["id"], metadata={"test_str": None})
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]

    # Delete all metadata keys
    coll["metadata"] = None
    sysdb.update_collection(coll["id"], metadata=None)
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]


# Segment ids are fixed and ascending so that sorting by id reproduces this
# list order exactly.
sample_segments = [
    Segment(
        id=uuid.UUID("00000000-d7d7-413b-92e1-731098a6e492"),
        type="test_type_a",
        scope=SegmentScope.VECTOR,
        topic=None,
        collection=sample_collections[0]["id"],
        metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3},
    ),
    Segment(
        id=uuid.UUID("11111111-d7d7-413b-92e1-731098a6e492"),
        type="test_type_b",
        topic="test_topic_2",
        scope=SegmentScope.VECTOR,
        collection=sample_collections[1]["id"],
        metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3},
    ),
    Segment(
        id=uuid.UUID("22222222-d7d7-413b-92e1-731098a6e492"),
        type="test_type_b",
        topic="test_topic_3",
        scope=SegmentScope.METADATA,
        collection=None,
        metadata={"test_str": "str3", "test_int": 3, "test_float": 3.3},
    ),
]


def test_create_get_delete_segments(sysdb: SysDB) -> None:
    """Create the sample segments, query them by each filter, then delete one."""
    sysdb.reset()

    for collection in sample_collections:
        sysdb.create_collection(collection)

    for segment in sample_segments:
        sysdb.create_segment(segment)

    results = sysdb.get_segments()
    results = sorted(results, key=lambda c: c["id"])

    assert results == sample_segments

    # Duplicate create fails
    with pytest.raises(UniqueConstraintError):
        sysdb.create_segment(sample_segments[0])

    # Find by id
    for segment in sample_segments:
        result = sysdb.get_segments(id=segment["id"])
        assert result == [segment]

    # Find by type
    result = sysdb.get_segments(type="test_type_a")
    assert result == sample_segments[:1]

    # Two segments share this type, so order the rows by id before comparing
    # instead of relying on the order the DB happens to return them in.
    result = sysdb.get_segments(type="test_type_b")
    assert sorted(result, key=lambda s: s["id"]) == sample_segments[1:]

    # Find by collection ID
    result = sysdb.get_segments(collection=sample_collections[0]["id"])
    assert result == sample_segments[:1]

    # Find by type and collection ID (positive case)
    result = sysdb.get_segments(
        type="test_type_a", collection=sample_collections[0]["id"]
    )
    assert result == sample_segments[:1]

    # Find by type and collection ID (negative case)
    result = sysdb.get_segments(
        type="test_type_b", collection=sample_collections[0]["id"]
    )
    assert result == []

    # Delete
    s1 = sample_segments[0]
    sysdb.delete_segment(s1["id"])

    results = sysdb.get_segments()
    assert s1 not in results
    assert len(results) == len(sample_segments) - 1
    # Both remaining segments have the same type, so sorting by type would
    # leave the DB's row order untouched; sort by the unique id instead.
    assert sorted(results, key=lambda c: c["id"]) == sample_segments[1:]

    # Duplicate delete throws an exception
    with pytest.raises(NotFoundError):
        sysdb.delete_segment(s1["id"])


def test_update_segment(sysdb: SysDB) -> None:
    """Update a segment's topic, collection and metadata and read each change back.

    The expected value is built by mutating `segment` (and the `metadata` dict
    it was constructed with) alongside every update call; the assertions rely
    on those in-place mutations being visible through `segment`.
    """
    metadata: Dict[str, Union[str, int, float]] = {
        "test_str": "str1",
        "test_int": 1,
        "test_float": 1.3,
    }
    segment = Segment(
        id=uuid.uuid4(),
        type="test_type_a",
        scope=SegmentScope.VECTOR,
        topic="test_topic_a",
        collection=sample_collections[0]["id"],
        metadata=metadata,
    )

    sysdb.reset()
    for c in sample_collections:
        sysdb.create_collection(c)

    sysdb.create_segment(segment)

    # Update topic to new value
    segment["topic"] = "new_topic"
    sysdb.update_segment(segment["id"], topic=segment["topic"])
    result = sysdb.get_segments(id=segment["id"])
    assert result == [segment]

    # Update topic to None
    segment["topic"] = None
    sysdb.update_segment(segment["id"], topic=segment["topic"])
    result = sysdb.get_segments(id=segment["id"])
    assert result == [segment]

    # Update collection to new value
    segment["collection"] = sample_collections[1]["id"]
    sysdb.update_segment(segment["id"], collection=segment["collection"])
    result = sysdb.get_segments(id=segment["id"])
    assert result == [segment]

    # Update collection to None
    segment["collection"] = None
    sysdb.update_segment(segment["id"], collection=segment["collection"])
    result = sysdb.get_segments(id=segment["id"])
    assert result == [segment]

    # Add a new metadata key
    metadata["test_str2"] = "str2"
    sysdb.update_segment(segment["id"], metadata={"test_str2": "str2"})
    result = sysdb.get_segments(id=segment["id"])
    assert result == [segment]

    # Update a metadata key
    metadata["test_str"] = "str3"
    sysdb.update_segment(segment["id"], metadata={"test_str": "str3"})
    result = sysdb.get_segments(id=segment["id"])
    assert result == [segment]

    # Delete a metadata key
    del metadata["test_str"]
    sysdb.update_segment(segment["id"], metadata={"test_str": None})
    result = sysdb.get_segments(id=segment["id"])
    assert result == [segment]

    # Delete all metadata keys
    segment["metadata"] = None
    sysdb.update_segment(segment["id"], metadata=None)
    result = sysdb.get_segments(id=segment["id"])
    assert result == [segment]