import hashlib
import hypothesis
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union
from typing_extensions import TypedDict
import numpy as np
import numpy.typing as npt
import chromadb.api.types as types
import re
from hypothesis.strategies._internal.strategies import SearchStrategy
from hypothesis.errors import InvalidDefinition
from hypothesis.stateful import RuleBasedStateMachine
from dataclasses import dataclass

from chromadb.api.types import Documents, Embeddings, Metadata

# Set the random seed for reproducibility
np.random.seed(0)  # unnecessary, hypothesis does this for us
# See Hypothesis documentation for creating strategies at
# https://hypothesis.readthedocs.io/en/latest/data.html

# NOTE: Because these strategies are used in state machines, we need to
# work around an issue with state machines, in which strategies that frequently
# are marked as invalid (i.e. through the use of `assume` or `.filter`) can cause the
# state machine tests to fail with a hypothesis.errors.Unsatisfiable.

# Ultimately this is because the entire state machine is run as a single Hypothesis
# example, which ends up drawing from the same strategies an enormous number of times.
# Whenever a strategy marks itself as invalid, Hypothesis tries to start the entire
# state machine run over. See https://github.com/HypothesisWorks/hypothesis/issues/3618

# Because strategy generation is all interrelated, seemingly small changes (especially
# ones called early in a test) can have an outsized effect. Generating lists with
# unique=True, or dictionaries with a min size, seems especially bad.

# Please make changes to these strategies incrementally, testing to make sure they don't
# start generating unsatisfiable examples.
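
# As an illustrative sketch of the failure mode described above (the strategy
# below is hypothetical and not used in this module): a filter that rejects
# almost every draw is fine in an ordinary @given test, but inside a state
# machine, where the same strategy is drawn from many times per example, it
# can make the whole run unsatisfiable:
#
#   mostly_invalid = st.integers().filter(lambda x: x % 1000 == 0)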

test_hnsw_config = {
    "hnsw:construction_ef": 128,
    "hnsw:search_ef": 128,
    "hnsw:M": 128,
}


class RecordSet(TypedDict):
    """
    A generated set of embeddings, ids, metadatas, and documents that
    represent what a user would pass to the API.
    """

    ids: Union[types.ID, List[types.ID]]
    embeddings: Optional[Union[types.Embeddings, types.Embedding]]
    metadatas: Optional[Union[List[types.Metadata], types.Metadata]]
    documents: Optional[Union[List[types.Document], types.Document]]


class NormalizedRecordSet(TypedDict):
    """
    A RecordSet, with all fields normalized to lists.
    """

    ids: List[types.ID]
    embeddings: Optional[types.Embeddings]
    metadatas: Optional[List[types.Metadata]]
    documents: Optional[List[types.Document]]


class StateMachineRecordSet(TypedDict):
    """
    Represents the internal state of a state machine in hypothesis tests.
    """

    ids: List[types.ID]
    embeddings: types.Embeddings
    metadatas: List[Optional[types.Metadata]]
    documents: List[Optional[types.Document]]


class Record(TypedDict):
    """
    A single generated record.
    """

    id: types.ID
    embedding: Optional[types.Embedding]
    metadata: Optional[types.Metadata]
    document: Optional[types.Document]


# TODO: support arbitrary text everywhere so we don't SQL-inject ourselves.
# TODO: support empty strings everywhere
sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"
safe_text = st.text(alphabet=sql_alphabet, min_size=1)

# Workaround for FastAPI json encoding peculiarities
# https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104
safe_text = safe_text.filter(lambda s: not s.startswith("_sa"))

safe_integers = st.integers(
    min_value=-(2**31), max_value=2**31 - 1
)  # TODO: handle longs
safe_floats = st.floats(
    allow_infinity=False,
    allow_nan=False,
    allow_subnormal=False,
    min_value=-1e6,
    max_value=1e6,
)  # TODO: handle infinity and NAN

safe_values: List[SearchStrategy[Union[int, float, str]]] = [
    safe_text,
    safe_integers,
    safe_floats,
]


def one_or_both(
    strategy_a: st.SearchStrategy[Any], strategy_b: st.SearchStrategy[Any]
) -> st.SearchStrategy[Any]:
    return st.one_of(
        st.tuples(strategy_a, strategy_b),
        st.tuples(strategy_a, st.none()),
        st.tuples(st.none(), strategy_b),
    )
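
# one_or_both yields 2-tuples in which at least one element is present. For
# example (illustrative values), one_or_both(st.just("a"), st.just(1)) can
# produce ("a", 1), ("a", None), or (None, 1) — but never (None, None).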

# Temporarily generate only these to avoid SQL formatting issues.
legal_id_characters = (
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./+"
)

float_types = [np.float16, np.float32, np.float64]
int_types = [np.int16, np.int32, np.int64]  # TODO: handle int types

@st.composite
def collection_name(draw: st.DrawFn) -> str:
    _collection_name_re = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{1,60}[a-zA-Z0-9]$")
    _ipv4_address_re = re.compile(r"^([0-9]{1,3}\.){3}[0-9]{1,3}$")
    _two_periods_re = re.compile(r"\.\.")

    name: str = draw(st.from_regex(_collection_name_re))
    hypothesis.assume(not _ipv4_address_re.match(name))
    hypothesis.assume(not _two_periods_re.search(name))

    return name
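
# Illustrative note: the regex above yields names such as "my-collection-1"
# (3 to 62 characters, starting with a letter and ending alphanumeric), while
# the assume() calls additionally reject IPv4-lookalike names and names
# containing consecutive periods, mirroring Chroma's collection-name rules.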

collection_metadata = st.one_of(
    st.none(), st.dictionaries(safe_text, st.one_of(*safe_values))
)


# TODO: Use a hypothesis strategy while maintaining embedding uniqueness
# Or handle duplicate embeddings within a known epsilon
def create_embeddings(
    dim: int,
    count: int,
    dtype: npt.DTypeLike,
) -> types.Embeddings:
    embeddings: types.Embeddings = (
        np.random.uniform(
            low=-1.0,
            high=1.0,
            size=(count, dim),
        )
        .astype(dtype)
        .tolist()
    )
    return embeddings
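
# For example, create_embeddings(dim=2, count=3, dtype=np.float32) returns a
# 3x2 nested list of floats drawn uniformly from [-1, 1), e.g.
# [[0.09, 0.43], [-0.77, 0.25], [0.61, -0.02]] (values illustrative; the
# actual output depends on numpy's global RNG state).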

class hashing_embedding_function(types.EmbeddingFunction):
    def __init__(self, dim: int, dtype: npt.DTypeLike) -> None:
        self.dim = dim
        self.dtype = dtype

    def __call__(self, texts: types.Documents) -> types.Embeddings:
        # Hash the texts and convert to hex strings
        hashed_texts = [
            list(hashlib.sha256(text.encode("utf-8")).hexdigest()) for text in texts
        ]

        # Pad with repetition, or truncate the hex strings to the desired dimension
        padded_texts = [
            text * (self.dim // len(text)) + text[: self.dim % len(text)]
            for text in hashed_texts
        ]

        # Convert the hex strings to dtype
        embeddings: types.Embeddings = np.array(
            [[int(char, 16) / 15.0 for char in text] for text in padded_texts],
            dtype=self.dtype,
        ).tolist()

        return embeddings
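
# Sketch of the behavior above: embeddings are deterministic in the input
# text, so the same document always maps to the same vector. For instance,
# hashing_embedding_function(dim=4, dtype=np.float32)(["hello"]) returns one
# 4-dimensional vector with each component in [0, 1], derived from the first
# four hex digits of sha256("hello") (illustrative description, not a test).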

class not_implemented_embedding_function(types.EmbeddingFunction):
    def __call__(self, texts: Documents) -> Embeddings:
        assert False, "This embedding function is not implemented"


def embedding_function_strategy(
    dim: int, dtype: npt.DTypeLike
) -> st.SearchStrategy[types.EmbeddingFunction]:
    return st.just(hashing_embedding_function(dim, dtype))

@dataclass
class Collection:
    name: str
    metadata: Optional[types.Metadata]
    dimension: int
    dtype: npt.DTypeLike
    known_metadata_keys: types.Metadata
    known_document_keywords: List[str]
    has_documents: bool = False
    has_embeddings: bool = False
    embedding_function: Optional[types.EmbeddingFunction] = None

@st.composite
def collections(
    draw: st.DrawFn,
    add_filterable_data: bool = False,
    with_hnsw_params: bool = False,
    has_embeddings: Optional[bool] = None,
    has_documents: Optional[bool] = None,
) -> Collection:
    """Strategy to generate a Collection object. If add_filterable_data is True,
    then known_metadata_keys and known_document_keywords will be populated with
    consistent data."""

    assert not ((has_embeddings is False) and (has_documents is False))

    name = draw(collection_name())
    metadata = draw(collection_metadata)
    dimension = draw(st.integers(min_value=2, max_value=2048))
    dtype = draw(st.sampled_from(float_types))

    if with_hnsw_params:
        if metadata is None:
            metadata = {}
        metadata.update(test_hnsw_config)
        # Sometimes, select a space at random
        if draw(st.booleans()):
            # TODO: pull the distance functions from a source of truth that lives not
            # in tests once https://github.com/chroma-core/issues/issues/61 lands
            metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))

    known_metadata_keys: Dict[str, Union[int, str, float]] = {}
    if add_filterable_data:
        while len(known_metadata_keys) < 5:
            key = draw(safe_text)
            known_metadata_keys[key] = draw(st.one_of(*safe_values))

    if has_documents is None:
        has_documents = draw(st.booleans())
    assert has_documents is not None
    if has_documents and add_filterable_data:
        known_document_keywords = draw(st.lists(safe_text, min_size=5, max_size=5))
    else:
        known_document_keywords = []

    if not has_documents:
        has_embeddings = True
    else:
        if has_embeddings is None:
            has_embeddings = draw(st.booleans())
    assert has_embeddings is not None

    embedding_function = draw(embedding_function_strategy(dimension, dtype))

    return Collection(
        name=name,
        metadata=metadata,
        dimension=dimension,
        dtype=dtype,
        known_metadata_keys=known_metadata_keys,
        has_documents=has_documents,
        known_document_keywords=known_document_keywords,
        has_embeddings=has_embeddings,
        embedding_function=embedding_function,
    )
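
# Hedged usage sketch: since collections is an @st.composite strategy, it is
# called to build a strategy and handed to hypothesis.given, e.g.
#
#   @hypothesis.given(coll=collections(add_filterable_data=True, with_hnsw_params=True))
#   def test_something(coll: Collection) -> None:
#       ...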

@st.composite
def metadata(draw: st.DrawFn, collection: Collection) -> types.Metadata:
    """Strategy for generating metadata that could be a part of the given collection"""
    # First draw a random dictionary.
    metadata: types.Metadata = draw(st.dictionaries(safe_text, st.one_of(*safe_values)))
    # Then, remove keys that overlap with the known keys for the collection,
    # to avoid type errors when comparing.
    if collection.known_metadata_keys:
        for key in collection.known_metadata_keys.keys():
            if key in metadata:
                del metadata[key]
    # Finally, add in some of the known keys for the collection
    sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
        k: st.just(v) for k, v in collection.known_metadata_keys.items()
    }
    metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict)))
    return metadata

@st.composite
def document(draw: st.DrawFn, collection: Collection) -> types.Document:
    """Strategy for generating documents that could be a part of the given collection"""

    if collection.known_document_keywords:
        known_words_st = st.sampled_from(collection.known_document_keywords)
    else:
        known_words_st = st.text(min_size=1)

    random_words_st = st.text(min_size=1)
    words = draw(st.lists(st.one_of(known_words_st, random_words_st), min_size=1))
    return " ".join(words)

@st.composite
def recordsets(
    draw: st.DrawFn,
    collection_strategy: SearchStrategy[Collection] = collections(),
    id_strategy: SearchStrategy[str] = safe_text,
    min_size: int = 1,
    max_size: int = 50,
) -> RecordSet:
    collection = draw(collection_strategy)

    ids = list(
        draw(st.lists(id_strategy, min_size=min_size, max_size=max_size, unique=True))
    )

    embeddings: Optional[Embeddings] = None
    if collection.has_embeddings:
        embeddings = create_embeddings(collection.dimension, len(ids), collection.dtype)
    metadatas = draw(
        st.lists(metadata(collection), min_size=len(ids), max_size=len(ids))
    )
    documents: Optional[Documents] = None
    if collection.has_documents:
        documents = draw(
            st.lists(document(collection), min_size=len(ids), max_size=len(ids))
        )

    # In the case where we have a single record, sometimes exercise
    # the code that handles individual values rather than lists.
    # In this case, any field may be a list or a single value.
    if len(ids) == 1:
        single_id: Union[str, List[str]] = ids[0] if draw(st.booleans()) else ids
        single_embedding = (
            embeddings[0]
            if embeddings is not None and draw(st.booleans())
            else embeddings
        )
        single_metadata: Union[Metadata, List[Metadata]] = (
            metadatas[0] if draw(st.booleans()) else metadatas
        )
        single_document = (
            documents[0] if documents is not None and draw(st.booleans()) else documents
        )
        return {
            "ids": single_id,
            "embeddings": single_embedding,
            "metadatas": single_metadata,
            "documents": single_document,
        }

    return {
        "ids": ids,
        "embeddings": embeddings,
        "metadatas": metadatas,
        "documents": documents,
    }
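
# Hedged usage sketch: a test typically draws a RecordSet generated for a
# specific collection configuration, e.g.
#
#   @hypothesis.given(record_set=recordsets(collections(has_embeddings=True)))
#   def test_add(record_set: RecordSet) -> None:
#       ...  # normalize fields to lists before asserting on them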

# This class is mostly cloned from hypothesis.stateful.RuleStrategy,
# but always runs all the rules, instead of using a FeatureStrategy to
# enable/disable rules. Disabled rules cause the entire test to be marked
# invalid and, combined with the complexity of our other strategies, lead to an
# unacceptably increased incidence of hypothesis.errors.Unsatisfiable.
class DeterministicRuleStrategy(SearchStrategy):  # type: ignore
    def __init__(self, machine: RuleBasedStateMachine) -> None:
        super().__init__()  # type: ignore
        self.machine = machine
        self.rules = list(machine.rules())  # type: ignore

        # The order is a bit arbitrary. Primarily we're trying to group rules
        # that write to the same location together, and to put rules with no
        # target first as they have less effect on the structure. We order from
        # fewer to more arguments on grounds that it will plausibly need less
        # data. This probably won't work especially well and we could be
        # smarter about it, but it's better than just doing it in definition
        # order.
        self.rules.sort(
            key=lambda rule: (
                sorted(rule.targets),
                len(rule.arguments),
                rule.function.__name__,
            )
        )

    def __repr__(self) -> str:
        return "{}(machine={}({{...}}))".format(
            self.__class__.__name__,
            self.machine.__class__.__name__,
        )

    def do_draw(self, data):  # type: ignore
        if not any(self.is_valid(rule) for rule in self.rules):
            msg = f"No progress can be made from state {self.machine!r}"
            raise InvalidDefinition(msg) from None

        rule = data.draw(st.sampled_from([r for r in self.rules if self.is_valid(r)]))
        argdata = data.draw(rule.arguments_strategy)
        return (rule, argdata)

    def is_valid(self, rule) -> bool:  # type: ignore
        if not all(precond(self.machine) for precond in rule.preconditions):
            return False

        for b in rule.bundles:
            bundle = self.machine.bundle(b.name)  # type: ignore
            if not bundle:
                return False
        return True
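
# Hedged usage sketch: a state machine can opt into this strategy by replacing
# the rule-selection strategy Hypothesis holds internally. The attribute name
# below follows hypothesis.stateful's internals and may change between
# hypothesis versions:
#
#   class MyStateMachine(RuleBasedStateMachine):
#       def __init__(self) -> None:
#           super().__init__()
#           self._rules_strategy = DeterministicRuleStrategy(self)  # type: ignore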

@st.composite
def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
    """Generate a filter that could be used in a query against the given collection"""

    known_keys = sorted(collection.known_metadata_keys.keys())

    key = draw(st.sampled_from(known_keys))
    value = collection.known_metadata_keys[key]

    legal_ops: List[Optional[str]] = [None, "$eq", "$ne"]
    if not isinstance(value, str):
        legal_ops.extend(["$gt", "$lt", "$lte", "$gte"])

    if isinstance(value, float):
        # Add or subtract a small number to avoid floating point rounding errors
        value = value + draw(st.sampled_from([1e-6, -1e-6]))

    op: types.WhereOperator = draw(st.sampled_from(legal_ops))

    if op is None:
        return {key: value}
    else:
        return {key: {op: value}}
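
# Example shapes this strategy can produce (keys and values illustrative):
#   {"color": "red"}           # op is None: implicit equality
#   {"count": {"$gte": 10}}    # explicit comparison operator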

@st.composite
def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocument:
    """Generate a where_document filter that could be used against the given collection"""
    if collection.known_document_keywords:
        word = draw(st.sampled_from(collection.known_document_keywords))
    else:
        word = draw(safe_text)
    return {"$contains": word}

def binary_operator_clause(
    base_st: SearchStrategy[types.Where],
) -> SearchStrategy[types.Where]:
    op: SearchStrategy[types.LogicalOperator] = st.sampled_from(["$and", "$or"])
    return st.dictionaries(
        keys=op,
        values=st.lists(base_st, max_size=2, min_size=2),
        min_size=1,
        max_size=1,
    )


def binary_document_operator_clause(
    base_st: SearchStrategy[types.WhereDocument],
) -> SearchStrategy[types.WhereDocument]:
    op: SearchStrategy[types.LogicalOperator] = st.sampled_from(["$and", "$or"])
    return st.dictionaries(
        keys=op,
        values=st.lists(base_st, max_size=2, min_size=2),
        min_size=1,
        max_size=1,
    )

@st.composite
def recursive_where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
    base_st = where_clause(collection)
    where: types.Where = draw(st.recursive(base_st, binary_operator_clause))
    return where


@st.composite
def recursive_where_doc_clause(
    draw: st.DrawFn, collection: Collection
) -> types.WhereDocument:
    base_st = where_doc_clause(collection)
    where: types.WhereDocument = draw(
        st.recursive(base_st, binary_document_operator_clause)
    )
    return where
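
# st.recursive starts from the base clause and repeatedly wraps it in binary
# $and/$or nodes (always exactly two children each), so generated filters
# range from a flat {"key": "value"} up to nested trees such as this
# illustrative example:
#   {"$or": [{"count": {"$gt": 1}}, {"$and": [{"a": "x"}, {"b": "y"}]}]}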

class Filter(TypedDict):
    where: Optional[types.Where]
    ids: Optional[Union[str, List[str]]]
    where_document: Optional[types.WhereDocument]

@st.composite
def filters(
    draw: st.DrawFn,
    collection_st: st.SearchStrategy[Collection],
    recordset_st: st.SearchStrategy[RecordSet],
    include_all_ids: bool = False,
) -> Filter:
    collection = draw(collection_st)
    recordset = draw(recordset_st)

    where_clause = draw(st.one_of(st.none(), recursive_where_clause(collection)))
    where_document_clause = draw(
        st.one_of(st.none(), recursive_where_doc_clause(collection))
    )

    ids: Optional[Union[List[types.ID], types.ID]]
    # Record sets can be a value instead of a list of values if there is only one record
    if isinstance(recordset["ids"], str):
        ids = [recordset["ids"]]
    else:
        ids = recordset["ids"]

    if not include_all_ids:
        ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids))))
        if ids is not None:
            # Remove duplicates since hypothesis samples with replacement
            ids = list(set(ids))

    # Test both the single value list and the unwrapped single value case
    if ids is not None and len(ids) == 1 and draw(st.booleans()):
        ids = ids[0]

    return {"where": where_clause, "where_document": where_document_clause, "ids": ids}
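
# Hedged usage sketch: filters are usually drawn against shared collection and
# record-set strategies so that the where/ids clauses refer to data that was
# actually generated, e.g.
#
#   collection_st = st.shared(collections(add_filterable_data=True), key="coll")
#   recordset_st = st.shared(recordsets(collection_st), key="recordset")
#
#   @hypothesis.given(filter=filters(collection_st, recordset_st))
#   def test_get(filter: Filter) -> None:
#       ...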