import hashlib
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import hypothesis
import hypothesis.strategies as st
import numpy as np
import numpy.typing as npt
from hypothesis.errors import InvalidDefinition
from hypothesis.stateful import RuleBasedStateMachine
from hypothesis.strategies._internal.strategies import SearchStrategy
from typing_extensions import TypedDict

import chromadb.api.types as types
from chromadb.api.types import Documents, Embeddings, Metadata

# Seed numpy's global RNG for reproducibility. This is redundant in practice:
# Hypothesis already takes care of seeding for us.
np.random.seed(0)

# Strategy-authoring reference:
# https://hypothesis.readthedocs.io/en/latest/data.html

# NOTE: These strategies are used inside state machine tests, which have a known
# weakness: strategies that frequently reject their own data (via `assume` or
# `.filter`) can make a state machine test fail with
# hypothesis.errors.Unsatisfiable.
# The root cause is that a whole state machine run is one single Hypothesis
# example, and it draws from the same strategies an enormous number of times.
# Every time a strategy rejects a draw, Hypothesis restarts the entire state
# machine run. See https://github.com/HypothesisWorks/hypothesis/issues/3618
# Because strategy generation is all interrelated, seemingly small changes
# (especially to strategies drawn early in a test) can have an outsized effect.
# Lists generated with unique=True, and dictionaries with a min size, seem
# especially bad.
# Please change these strategies incrementally, checking that they don't start
# generating unsatisfiable examples.

# HNSW index parameters merged into collection metadata when a test asks for them.
test_hnsw_config = {
    "hnsw:construction_ef": 128,
    "hnsw:search_ef": 128,
    "hnsw:M": 128,
}


class RecordSet(TypedDict):
    """
    Generated ids, embeddings, metadatas and documents shaped the way a user
    would pass them to the API: each field may be a single value or a list.
    """

    ids: Union[types.ID, List[types.ID]]
    embeddings: Optional[Union[types.Embeddings, types.Embedding]]
    metadatas: Optional[Union[List[types.Metadata], types.Metadata]]
    documents: Optional[Union[List[types.Document], types.Document]]


class NormalizedRecordSet(TypedDict):
    """
    The same data as a RecordSet, but with every field normalized to a list.
    """

    ids: List[types.ID]
    embeddings: Optional[types.Embeddings]
    metadatas: Optional[List[types.Metadata]]
    documents: Optional[List[types.Document]]


class StateMachineRecordSet(TypedDict):
    """
    The records a Hypothesis state machine test tracks as its internal state.
    """

    ids: List[types.ID]
    embeddings: types.Embeddings
    metadatas: List[Optional[types.Metadata]]
    documents: List[Optional[types.Document]]


class Record(TypedDict):
    """
    One generated record.
    """

    id: types.ID
    embedding: Optional[types.Embedding]
    metadata: Optional[types.Metadata]
    document: Optional[types.Document]


# TODO: support arbitrary text everywhere so we don't SQL-inject ourselves.
# TODO: support empty strings everywhere
sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"

# Non-empty text over sql_alphabet. Strings starting with "_sa" are rejected to
# work around a FastAPI JSON-encoding peculiarity:
# https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104
safe_text = st.text(alphabet=sql_alphabet, min_size=1).filter(
    lambda s: not s.startswith("_sa")
)

# Signed 32-bit integer range.
safe_integers = st.integers(-(2**31), 2**31 - 1)  # TODO: handle longs

# Finite, normal floats in [-1e6, 1e6].
safe_floats = st.floats(
    min_value=-1e6,
    max_value=1e6,
    allow_nan=False,
    allow_infinity=False,
    allow_subnormal=False,
)  # TODO: handle infinity and NAN

# Strategies for the scalar value types used in metadata.
safe_values: List[SearchStrategy[Union[int, float, str]]] = [
    safe_text,
    safe_integers,
    safe_floats,
]


def one_or_both(
    strategy_a: st.SearchStrategy[Any], strategy_b: st.SearchStrategy[Any]
) -> st.SearchStrategy[Any]:
    """Draw a pair in which at least one side is present.

    Yields (a, b), (a, None) or (None, b), with a drawn from strategy_a and b
    drawn from strategy_b.
    """
    both_present = st.tuples(strategy_a, strategy_b)
    only_first = st.tuples(strategy_a, st.none())
    only_second = st.tuples(st.none(), strategy_b)
    return st.one_of(both_present, only_first, only_second)
# Temporarily generate only these to avoid SQL formatting issues.
legal_id_characters = (
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./+"
)

float_types = [np.float16, np.float32, np.float64]
int_types = [np.int16, np.int32, np.int64]  # TODO: handle int types


@st.composite
def collection_name(draw: st.DrawFn) -> str:
    """Strategy for collection names.

    Names start with a letter, end with a letter or digit, and otherwise contain
    only letters, digits and hyphens (3-62 characters in total).
    """
    _collection_name_re = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{1,60}[a-zA-Z0-9]$")
    _ipv4_address_re = re.compile(r"^([0-9]{1,3}\.){3}[0-9]{1,3}$")
    _two_periods_re = re.compile(r"\.\.")

    # NOTE(review): without fullmatch=True, Hypothesis' from_regex may append a
    # trailing "\n", because `$` also matches before a final newline. Confirm
    # whether such names are intended.
    name: str = draw(st.from_regex(_collection_name_re))
    # The regex above cannot produce ".", so these checks never reject today.
    # They presumably mirror further naming rules (no IPv4-like names, no "..").
    hypothesis.assume(not _ipv4_address_re.match(name))
    hypothesis.assume(not _two_periods_re.search(name))

    return name


# Collection-level metadata: either absent or an arbitrary dict of safe values.
collection_metadata = st.one_of(
    st.none(), st.dictionaries(safe_text, st.one_of(*safe_values))
)


# TODO: Use a hypothesis strategy while maintaining embedding uniqueness
# Or handle duplicate embeddings within a known epsilon
def create_embeddings(
    dim: int,
    count: int,
    dtype: npt.DTypeLike,
) -> types.Embeddings:
    """Return `count` random embeddings, each with `dim` components.

    Values are drawn uniformly from [-1.0, 1.0) using numpy's global RNG. They
    are cast to `dtype` and then converted back to nested Python lists.
    """
    embeddings: types.Embeddings = (
        np.random.uniform(
            low=-1.0,
            high=1.0,
            size=(count, dim),
        )
        .astype(dtype)
        .tolist()
    )
    return embeddings


class hashing_embedding_function(types.EmbeddingFunction):
    """Deterministic embedding function derived from SHA-256 hashes of the texts.

    Each hex digit of a text's digest becomes one component, scaled to [0, 1].
    The digest is repeated or truncated to reach `dim` components.
    """

    def __init__(self, dim: int, dtype: npt.DTypeLike) -> None:
        self.dim = dim
        self.dtype = dtype

    def __call__(self, texts: types.Documents) -> types.Embeddings:
        # Hash the texts and split each hex digest into a list of characters
        hashed_texts = [
            list(hashlib.sha256(text.encode("utf-8")).hexdigest()) for text in texts
        ]

        # Pad with repetition, or truncate the hex strings to the desired dimension
        padded_texts = [
            text * (self.dim // len(text)) + text[: self.dim % len(text)]
            for text in hashed_texts
        ]

        # Map each hex digit (0-15) into [0, 1] and cast to dtype
        embeddings: types.Embeddings = np.array(
            [[int(char, 16) / 15.0 for char in text] for text in padded_texts],
            dtype=self.dtype,
        ).tolist()
        return embeddings


class not_implemented_embedding_function(types.EmbeddingFunction):
    """Embedding function that must never be invoked; calling it fails the test."""

    def __call__(self, texts: Documents) -> Embeddings:
        # Raise explicitly instead of `assert False`. Under `python -O` the
        # assert is stripped, and the call would silently return None.
        raise AssertionError("This embedding function is not implemented")


def embedding_function_strategy(
    dim: int, dtype: npt.DTypeLike
) -> st.SearchStrategy[types.EmbeddingFunction]:
    """Strategy that always yields a hashing_embedding_function for dim/dtype."""
    return st.just(hashing_embedding_function(dim, dtype))


@dataclass
class Collection:
    """A generated collection description that drives the other strategies."""

    name: str
    metadata: Optional[types.Metadata]
    dimension: int
    dtype: npt.DTypeLike
    # Metadata keys with fixed values that generated records may carry; used to
    # build where clauses that can match.
    known_metadata_keys: types.Metadata
    # Words that generated documents may contain; used by where_document clauses.
    known_document_keywords: List[str]
    has_documents: bool = False
    has_embeddings: bool = False
    embedding_function: Optional[types.EmbeddingFunction] = None


@st.composite
def collections(
    draw: st.DrawFn,
    add_filterable_data: bool = False,
    with_hnsw_params: bool = False,
    has_embeddings: Optional[bool] = None,
    has_documents: Optional[bool] = None,
) -> Collection:
    """Strategy to generate a Collection object.

    If add_filterable_data is True, then known_metadata_keys and
    known_document_keywords will be populated with consistent data.
    If with_hnsw_params is True, test_hnsw_config is merged into the metadata,
    and sometimes a random "hnsw:space" is added as well.
    When has_embeddings / has_documents are not None, the corresponding flag is
    forced to that value. They may not both be False.
    """
    assert not ((has_embeddings is False) and (has_documents is False))

    name = draw(collection_name())
    metadata = draw(collection_metadata)
    dimension = draw(st.integers(min_value=2, max_value=2048))
    dtype = draw(st.sampled_from(float_types))

    if with_hnsw_params:
        if metadata is None:
            metadata = {}
        metadata.update(test_hnsw_config)
        # Sometimes, select a space at random
        if draw(st.booleans()):
            # TODO: pull the distance functions from a source of truth that lives not
            # in tests once https://github.com/chroma-core/issues/issues/61 lands
            metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))

    known_metadata_keys: Dict[str, Union[int, str, float]] = {}
    if add_filterable_data:
        while len(known_metadata_keys) < 5:
            key = draw(safe_text)
            known_metadata_keys[key] = draw(st.one_of(*safe_values))

    if has_documents is None:
        if has_embeddings is False:
            # A collection needs documents or embeddings. The caller explicitly
            # ruled out embeddings, so documents are mandatory; drawing False
            # here would later override the caller's has_embeddings=False.
            has_documents = True
        else:
            has_documents = draw(st.booleans())
    assert has_documents is not None

    if has_documents and add_filterable_data:
        known_document_keywords = draw(st.lists(safe_text, min_size=5, max_size=5))
    else:
        known_document_keywords = []

    if not has_documents:
        # A collection without documents must have embeddings.
        has_embeddings = True
    elif has_embeddings is None:
        has_embeddings = draw(st.booleans())
    assert has_embeddings is not None

    embedding_function = draw(embedding_function_strategy(dimension, dtype))

    return Collection(
        name=name,
        metadata=metadata,
        dimension=dimension,
        dtype=dtype,
        known_metadata_keys=known_metadata_keys,
        has_documents=has_documents,
        known_document_keywords=known_document_keywords,
        has_embeddings=has_embeddings,
        embedding_function=embedding_function,
    )


@st.composite
def metadata(draw: st.DrawFn, collection: Collection) -> types.Metadata:
    """Strategy for generating metadata that could be a part of the given collection"""
    # First draw a random dictionary.
    generated: types.Metadata = draw(
        st.dictionaries(safe_text, st.one_of(*safe_values))
    )
    # Then, remove keys that overlap with the known keys for the collection
    # to avoid type errors when comparing.
    for key in collection.known_metadata_keys:
        generated.pop(key, None)
    # Finally, add in any subset of the known keys with their fixed values.
    sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
        k: st.just(v) for k, v in collection.known_metadata_keys.items()
    }
    generated.update(draw(st.fixed_dictionaries({}, optional=sampling_dict)))
    return generated


@st.composite
def document(draw: st.DrawFn, collection: Collection) -> types.Document:
    """Strategy for generating documents that could be a part of the given collection"""

    if collection.known_document_keywords:
        known_words_st = st.sampled_from(collection.known_document_keywords)
    else:
        known_words_st = st.text(min_size=1)

    random_words_st = st.text(min_size=1)
    words = draw(st.lists(st.one_of(known_words_st, random_words_st), min_size=1))
    return " ".join(words)


@st.composite
def recordsets(
    draw: st.DrawFn,
    collection_strategy: SearchStrategy[Collection] = collections(),
    id_strategy: SearchStrategy[str] = safe_text,
    min_size: int = 1,
    max_size: int = 50,
) -> RecordSet:
    """Strategy for a RecordSet of min_size..max_size records with unique ids.

    Embeddings and documents are present only if the drawn collection has them.
    """
    collection = draw(collection_strategy)

    ids = list(
        draw(st.lists(id_strategy, min_size=min_size, max_size=max_size, unique=True))
    )

    embeddings: Optional[Embeddings] = None
    if collection.has_embeddings:
        embeddings = create_embeddings(collection.dimension, len(ids), collection.dtype)
    metadatas = draw(
        st.lists(metadata(collection), min_size=len(ids), max_size=len(ids))
    )
    documents: Optional[Documents] = None
    if collection.has_documents:
        documents = draw(
            st.lists(document(collection), min_size=len(ids), max_size=len(ids))
        )

    # in the case where we have a single record, sometimes exercise
    # the code that handles individual values rather than lists.
    # In this case, any field may be a list or a single value.
    if len(ids) == 1:
        single_id: Union[str, List[str]] = ids[0] if draw(st.booleans()) else ids
        single_embedding = (
            embeddings[0]
            if embeddings is not None and draw(st.booleans())
            else embeddings
        )
        single_metadata: Union[Metadata, List[Metadata]] = (
            metadatas[0] if draw(st.booleans()) else metadatas
        )
        single_document = (
            documents[0] if documents is not None and draw(st.booleans()) else documents
        )
        return {
            "ids": single_id,
            "embeddings": single_embedding,
            "metadatas": single_metadata,
            "documents": single_document,
        }

    return {
        "ids": ids,
        "embeddings": embeddings,
        "metadatas": metadatas,
        "documents": documents,
    }


# This class is mostly cloned from hypothesis.stateful.RuleStrategy, but it
# always runs all the rules instead of using a FeatureStrategy to enable or
# disable them. Disabled rules cause the entire test to be marked invalid.
# Combined with the complexity of our other strategies, that leads to an
# unacceptably high incidence of hypothesis.errors.Unsatisfiable.
class DeterministicRuleStrategy(SearchStrategy):  # type: ignore
    def __init__(self, machine: RuleBasedStateMachine) -> None:
        super().__init__()  # type: ignore
        self.machine = machine
        self.rules = list(machine.rules())  # type: ignore

        # The order is a bit arbitrary. Primarily we're trying to group rules
        # that write to the same location together, and to put rules with no
        # target first as they have less effect on the structure. We order from
        # fewer to more arguments on grounds that it will plausibly need less
        # data. This probably won't work especially well and we could be
        # smarter about it, but it's better than just doing it in definition
        # order.
        self.rules.sort(
            key=lambda rule: (
                sorted(rule.targets),
                len(rule.arguments),
                rule.function.__name__,
            )
        )

    def __repr__(self) -> str:
        return "{}(machine={}({{...}}))".format(
            self.__class__.__name__,
            self.machine.__class__.__name__,
        )

    def do_draw(self, data):  # type: ignore
        """Draw a currently-valid rule and then its arguments."""
        # Evaluate each rule's preconditions once. The list is used both for
        # the "no progress" check and for sampling.
        valid_rules = [rule for rule in self.rules if self.is_valid(rule)]
        if not valid_rules:
            msg = f"No progress can be made from state {self.machine!r}"
            raise InvalidDefinition(msg) from None

        rule = data.draw(st.sampled_from(valid_rules))
        argdata = data.draw(rule.arguments_strategy)
        return (rule, argdata)

    def is_valid(self, rule) -> bool:  # type: ignore
        """A rule is valid if all preconditions hold and all its bundles are non-empty."""
        if not all(precond(self.machine) for precond in rule.preconditions):
            return False

        for b in rule.bundles:
            bundle = self.machine.bundle(b.name)  # type: ignore
            if not bundle:
                return False
        return True


@st.composite
def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
    """Generate a filter that could be used in a query against the given collection"""

    known_keys = sorted(collection.known_metadata_keys.keys())

    key = draw(st.sampled_from(known_keys))
    value = collection.known_metadata_keys[key]

    legal_ops: List[Optional[str]] = [None, "$eq", "$ne"]
    if not isinstance(value, str):
        # Ordering operators are only offered for non-string (numeric) values.
        legal_ops.extend(["$gt", "$lt", "$lte", "$gte"])
    if isinstance(value, float):
        # Add or subtract a small number to avoid floating point rounding errors
        value = value + draw(st.sampled_from([1e-6, -1e-6]))

    op: types.WhereOperator = draw(st.sampled_from(legal_ops))

    if op is None:
        return {key: value}
    else:
        return {key: {op: value}}


@st.composite
def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocument:
    """Generate a where_document filter that could be used against the given collection"""
    if collection.known_document_keywords:
        word = draw(st.sampled_from(collection.known_document_keywords))
    else:
        word = draw(safe_text)
    return {"$contains": word}


def _binary_logical_clause(base_st: SearchStrategy[Any]) -> SearchStrategy[Any]:
    """Combine two draws of base_st into {"$and" | "$or": [clause, clause]}."""
    op: SearchStrategy[types.LogicalOperator] = st.sampled_from(["$and", "$or"])
    return st.dictionaries(
        keys=op,
        values=st.lists(base_st, max_size=2, min_size=2),
        min_size=1,
        max_size=1,
    )


def binary_operator_clause(
    base_st: SearchStrategy[types.Where],
) -> SearchStrategy[types.Where]:
    """Join two metadata where clauses with $and / $or."""
    return _binary_logical_clause(base_st)


def binary_document_operator_clause(
    base_st: SearchStrategy[types.WhereDocument],
) -> SearchStrategy[types.WhereDocument]:
    """Join two where_document clauses with $and / $or."""
    return _binary_logical_clause(base_st)


@st.composite
def recursive_where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
    """Arbitrarily nested $and/$or combinations of where_clause for the collection."""
    base_st = where_clause(collection)
    where: types.Where = draw(st.recursive(base_st, binary_operator_clause))
    return where


@st.composite
def recursive_where_doc_clause(
    draw: st.DrawFn, collection: Collection
) -> types.WhereDocument:
    """Arbitrarily nested $and/$or combinations of where_doc_clause for the collection."""
    base_st = where_doc_clause(collection)
    where: types.WhereDocument = draw(
        st.recursive(base_st, binary_document_operator_clause)
    )
    return where


class Filter(TypedDict):
    where: Optional[types.Where]
    ids: Optional[Union[str, List[str]]]
    where_document: Optional[types.WhereDocument]


@st.composite
def filters(
    draw: st.DrawFn,
    collection_st: st.SearchStrategy[Collection],
    recordset_st: st.SearchStrategy[RecordSet],
    include_all_ids: bool = False,
) -> Filter:
    """Strategy for a query Filter (where, where_document, ids) over a recordset.

    If include_all_ids is True, `ids` holds every id of the recordset. Otherwise
    it is None or a de-duplicated sample of those ids. A one-element list may be
    unwrapped to a bare id.
    """
    collection = draw(collection_st)
    recordset = draw(recordset_st)

    where = draw(st.one_of(st.none(), recursive_where_clause(collection)))
    where_document = draw(st.one_of(st.none(), recursive_where_doc_clause(collection)))

    ids: Optional[Union[List[types.ID], types.ID]]
    # Record sets can be a value instead of a list of values if there is only one record
    if isinstance(recordset["ids"], str):
        ids = [recordset["ids"]]
    else:
        ids = recordset["ids"]

    if not include_all_ids:
        ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids))))
        if ids is not None:
            # Remove duplicates, since hypothesis samples with replacement.
            # dict.fromkeys keeps the draw order. list(set(...)) would make the
            # order depend on per-process string hash randomization, so the
            # same example would not replay identically.
            ids = list(dict.fromkeys(ids))

    # Test both the single value list and the unwrapped single value case
    if ids is not None and len(ids) == 1 and draw(st.booleans()):
        ids = ids[0]

    return {"where": where, "where_document": where_document, "ids": ids}