SungBeom's picture
Upload folder using huggingface_hub
4a51346
raw
history blame
15 kB
import pytest
from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence
from chromadb.config import System, Settings
from chromadb.types import (
SubmitEmbeddingRecord,
MetadataEmbeddingRecord,
Operation,
ScalarEncoding,
Segment,
SegmentScope,
SeqId,
)
from chromadb.ingest import Producer
from chromadb.segment import MetadataReader
import uuid
import time
from chromadb.segment.impl.metadata.sqlite import SqliteMetadataSegment
from pytest import FixtureRequest
from itertools import count
def sqlite() -> Generator[System, None, None]:
"""Fixture generator for sqlite DB"""
settings = Settings(sqlite_database=":memory:", allow_reset=True)
system = System(settings)
system.start()
yield system
system.stop()
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
return [sqlite]
@pytest.fixture(scope="module", params=system_fixtures())
def system(request: FixtureRequest) -> Generator[System, None, None]:
yield next(request.param())
@pytest.fixture(scope="function")
def sample_embeddings() -> Iterator[SubmitEmbeddingRecord]:
def create_record(i: int) -> SubmitEmbeddingRecord:
vector = [i + i * 0.1, i + 1 + i * 0.1]
metadata: Optional[Dict[str, Union[str, int, float]]]
if i == 0:
metadata = None
else:
metadata = {"str_key": f"value_{i}", "int_key": i, "float_key": i + i * 0.1}
if i % 3 == 0:
metadata["div_by_three"] = "true"
metadata["document"] = _build_document(i)
record = SubmitEmbeddingRecord(
id=f"embedding_{i}",
embedding=vector,
encoding=ScalarEncoding.FLOAT32,
metadata=metadata,
operation=Operation.ADD,
)
return record
return (create_record(i) for i in count())
_digit_map = {
"0": "zero",
"1": "one",
"2": "two",
"3": "three",
"4": "four",
"5": "five",
"6": "six",
"7": "seven",
"8": "eight",
"9": "nine",
}
def _build_document(i: int) -> str:
digits = list(str(i))
return " ".join(_digit_map[d] for d in digits)
segment_definition = Segment(
id=uuid.uuid4(),
type="test_type",
scope=SegmentScope.METADATA,
topic="persistent://test/test/test_topic_1",
collection=None,
metadata=None,
)
def sync(segment: MetadataReader, seq_id: SeqId) -> None:
# Try for up to 5 seconds, then throw a TimeoutError
start = time.time()
while time.time() - start < 5:
if segment.max_seqid() >= seq_id:
return
time.sleep(0.25)
raise TimeoutError(f"Timed out waiting for seq_id {seq_id}")
def test_insert_and_count(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
) -> None:
system.reset()
producer = system.instance(Producer)
topic = str(segment_definition["topic"])
max_id = 0
for i in range(3):
max_id = producer.submit_embedding(topic, next(sample_embeddings))
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
sync(segment, max_id)
assert segment.count() == 3
for i in range(3):
max_id = producer.submit_embedding(topic, next(sample_embeddings))
sync(segment, max_id)
assert segment.count() == 6
def assert_equiv_records(
expected: Sequence[SubmitEmbeddingRecord], actual: Sequence[MetadataEmbeddingRecord]
) -> None:
assert len(expected) == len(actual)
sorted_expected = sorted(expected, key=lambda r: r["id"])
sorted_actual = sorted(actual, key=lambda r: r["id"])
for e, a in zip(sorted_expected, sorted_actual):
assert e["id"] == a["id"]
assert e["metadata"] == a["metadata"]
def test_get(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
) -> None:
system.reset()
producer = system.instance(Producer)
topic = str(segment_definition["topic"])
embeddings = [next(sample_embeddings) for i in range(10)]
seq_ids = []
for e in embeddings:
seq_ids.append(producer.submit_embedding(topic, e))
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
sync(segment, seq_ids[-1])
# Get all records
results = segment.get_metadata()
assert seq_ids == [r["seq_id"] for r in results]
assert_equiv_records(embeddings, results)
# get by ID
result = segment.get_metadata(ids=[e["id"] for e in embeddings[0:5]])
assert_equiv_records(embeddings[0:5], result)
# Get with limit and offset
# Cannot rely on order(yet), but can rely on retrieving exactly the
# whole set eventually
ret: List[MetadataEmbeddingRecord] = []
ret.extend(segment.get_metadata(limit=3))
assert len(ret) == 3
ret.extend(segment.get_metadata(limit=3, offset=3))
assert len(ret) == 6
ret.extend(segment.get_metadata(limit=3, offset=6))
assert len(ret) == 9
ret.extend(segment.get_metadata(limit=3, offset=9))
assert len(ret) == 10
assert_equiv_records(embeddings, ret)
# Get with simple where
result = segment.get_metadata(where={"div_by_three": "true"})
assert len(result) == 3
# Get with gt/gte/lt/lte on int keys
result = segment.get_metadata(where={"int_key": {"$gt": 5}})
assert len(result) == 4
result = segment.get_metadata(where={"int_key": {"$gte": 5}})
assert len(result) == 5
result = segment.get_metadata(where={"int_key": {"$lt": 5}})
assert len(result) == 4
result = segment.get_metadata(where={"int_key": {"$lte": 5}})
assert len(result) == 5
# Get with gt/lt on float keys with float values
result = segment.get_metadata(where={"float_key": {"$gt": 5.01}})
assert len(result) == 5
result = segment.get_metadata(where={"float_key": {"$lt": 4.99}})
assert len(result) == 4
# Get with gt/lt on float keys with int values
result = segment.get_metadata(where={"float_key": {"$gt": 5}})
assert len(result) == 5
result = segment.get_metadata(where={"float_key": {"$lt": 5}})
assert len(result) == 4
# Get with gt/lt on int keys with float values
result = segment.get_metadata(where={"int_key": {"$gt": 5.01}})
assert len(result) == 4
result = segment.get_metadata(where={"int_key": {"$lt": 4.99}})
assert len(result) == 4
# Get with $ne
# Returns metadata that has an int_key, but not equal to 5
result = segment.get_metadata(where={"int_key": {"$ne": 5}})
assert len(result) == 8
# get with multiple heterogenous conditions
result = segment.get_metadata(where={"div_by_three": "true", "int_key": {"$gt": 5}})
assert len(result) == 2
# get with OR conditions
result = segment.get_metadata(where={"$or": [{"int_key": 1}, {"int_key": 2}]})
assert len(result) == 2
# get with AND conditions
result = segment.get_metadata(
where={"$and": [{"int_key": 3}, {"float_key": {"$gt": 5}}]}
)
assert len(result) == 0
result = segment.get_metadata(
where={"$and": [{"int_key": 3}, {"float_key": {"$lt": 5}}]}
)
assert len(result) == 1
def test_fulltext(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
) -> None:
system.reset()
producer = system.instance(Producer)
topic = str(segment_definition["topic"])
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
max_id = 0
for i in range(100):
max_id = producer.submit_embedding(topic, next(sample_embeddings))
sync(segment, max_id)
result = segment.get_metadata(where={"document": "four two"})
result2 = segment.get_metadata(ids=["embedding_42"])
assert result == result2
# Test single result
result = segment.get_metadata(where_document={"$contains": "four two"})
assert len(result) == 1
# Test many results
result = segment.get_metadata(where_document={"$contains": "zero"})
assert len(result) == 9
# test $and
result = segment.get_metadata(
where_document={"$and": [{"$contains": "four"}, {"$contains": "two"}]}
)
assert len(result) == 2
assert set([r["id"] for r in result]) == {"embedding_42", "embedding_24"}
# test $or
result = segment.get_metadata(
where_document={"$or": [{"$contains": "zero"}, {"$contains": "one"}]}
)
ones = [i for i in range(1, 100) if "one" in _build_document(i)]
zeros = [i for i in range(1, 100) if "zero" in _build_document(i)]
expected = set([f"embedding_{i}" for i in set(ones + zeros)])
assert set([r["id"] for r in result]) == expected
# test combo with where clause (negative case)
result = segment.get_metadata(
where={"int_key": {"$eq": 42}}, where_document={"$contains": "zero"}
)
assert len(result) == 0
# test combo with where clause (positive case)
result = segment.get_metadata(
where={"int_key": {"$eq": 42}}, where_document={"$contains": "four"}
)
assert len(result) == 1
# test partial words
result = segment.get_metadata(where_document={"$contains": "zer"})
assert len(result) == 9
def test_delete(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
) -> None:
system.reset()
producer = system.instance(Producer)
topic = str(segment_definition["topic"])
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
embeddings = [next(sample_embeddings) for i in range(10)]
max_id = 0
for e in embeddings:
max_id = producer.submit_embedding(topic, e)
sync(segment, max_id)
assert segment.count() == 10
results = segment.get_metadata(ids=["embedding_0"])
assert_equiv_records(embeddings[:1], results)
# Delete by ID
max_id = producer.submit_embedding(
topic,
SubmitEmbeddingRecord(
id="embedding_0",
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
),
)
sync(segment, max_id)
assert segment.count() == 9
assert segment.get_metadata(ids=["embedding_0"]) == []
# Delete is idempotent
max_id = producer.submit_embedding(
topic,
SubmitEmbeddingRecord(
id="embedding_0",
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
),
)
sync(segment, max_id)
assert segment.count() == 9
assert segment.get_metadata(ids=["embedding_0"]) == []
# re-add
max_id = producer.submit_embedding(topic, embeddings[0])
sync(segment, max_id)
assert segment.count() == 10
results = segment.get_metadata(ids=["embedding_0"])
def test_update(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
) -> None:
system.reset()
producer = system.instance(Producer)
topic = str(segment_definition["topic"])
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
_test_update(sample_embeddings, producer, segment, topic, Operation.UPDATE)
# Update nonexisting ID
update_record = SubmitEmbeddingRecord(
id="no_such_id",
metadata={"foo": "bar"},
embedding=None,
encoding=None,
operation=Operation.UPDATE,
)
max_id = producer.submit_embedding(topic, update_record)
sync(segment, max_id)
results = segment.get_metadata(ids=["no_such_id"])
assert len(results) == 0
assert segment.count() == 3
def test_upsert(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
) -> None:
system.reset()
producer = system.instance(Producer)
topic = str(segment_definition["topic"])
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
_test_update(sample_embeddings, producer, segment, topic, Operation.UPSERT)
# upsert previously nonexisting ID
update_record = SubmitEmbeddingRecord(
id="no_such_id",
metadata={"foo": "bar"},
embedding=None,
encoding=None,
operation=Operation.UPSERT,
)
max_id = producer.submit_embedding(topic, update_record)
sync(segment, max_id)
results = segment.get_metadata(ids=["no_such_id"])
assert results[0]["metadata"] == {"foo": "bar"}
def _test_update(
sample_embeddings: Iterator[SubmitEmbeddingRecord],
producer: Producer,
segment: MetadataReader,
topic: str,
op: Operation,
) -> None:
"""test code common between update and upsert paths"""
embeddings = [next(sample_embeddings) for i in range(3)]
max_id = 0
for e in embeddings:
max_id = producer.submit_embedding(topic, e)
sync(segment, max_id)
results = segment.get_metadata(ids=["embedding_0"])
assert_equiv_records(embeddings[:1], results)
# Update embedding with no metadata
update_record = SubmitEmbeddingRecord(
id="embedding_0",
metadata={"document": "foo bar"},
embedding=None,
encoding=None,
operation=op,
)
max_id = producer.submit_embedding(topic, update_record)
sync(segment, max_id)
results = segment.get_metadata(ids=["embedding_0"])
assert results[0]["metadata"] == {"document": "foo bar"}
results = segment.get_metadata(where_document={"$contains": "foo"})
assert results[0]["metadata"] == {"document": "foo bar"}
# Update and overrwrite key
update_record = SubmitEmbeddingRecord(
id="embedding_0",
metadata={"document": "biz buz"},
embedding=None,
encoding=None,
operation=op,
)
max_id = producer.submit_embedding(topic, update_record)
sync(segment, max_id)
results = segment.get_metadata(ids=["embedding_0"])
assert results[0]["metadata"] == {"document": "biz buz"}
results = segment.get_metadata(where_document={"$contains": "biz"})
assert results[0]["metadata"] == {"document": "biz buz"}
results = segment.get_metadata(where_document={"$contains": "foo"})
assert len(results) == 0
# Update and add key
update_record = SubmitEmbeddingRecord(
id="embedding_0",
metadata={"baz": 42},
embedding=None,
encoding=None,
operation=op,
)
max_id = producer.submit_embedding(topic, update_record)
sync(segment, max_id)
results = segment.get_metadata(ids=["embedding_0"])
assert results[0]["metadata"] == {"document": "biz buz", "baz": 42}
# Update and delete key
update_record = SubmitEmbeddingRecord(
id="embedding_0",
metadata={"document": None},
embedding=None,
encoding=None,
operation=op,
)
max_id = producer.submit_embedding(topic, update_record)
sync(segment, max_id)
results = segment.get_metadata(ids=["embedding_0"])
assert results[0]["metadata"] == {"baz": 42}
results = segment.get_metadata(where_document={"$contains": "biz"})
assert len(results) == 0