import math
from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet
from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast, Dict
from typing_extensions import Literal
import numpy as np
import numpy.typing as npt
from chromadb.api import types
from chromadb.api.models.Collection import Collection
from hypothesis import note
from hypothesis.errors import InvalidArgument

T = TypeVar("T")


def wrap(value: Union[T, List[T]]) -> List[T]:
    """Wrap a value in a list if it is not a list"""
    if value is None:
        raise InvalidArgument("value cannot be None")
    elif isinstance(value, List):
        return value
    else:
        return [value]


def wrap_all(record_set: RecordSet) -> NormalizedRecordSet:
    """Ensure that a record set has lists for all its values"""
    embedding_list: Optional[types.Embeddings]
    if record_set["embeddings"] is None:
        embedding_list = None
    elif isinstance(record_set["embeddings"], list):
        assert record_set["embeddings"] is not None
        if len(record_set["embeddings"]) > 0 and not all(
            isinstance(embedding, list) for embedding in record_set["embeddings"]
        ):
            if all(isinstance(e, (int, float)) for e in record_set["embeddings"]):
                embedding_list = cast(types.Embeddings, [record_set["embeddings"]])
            else:
                raise InvalidArgument("an embedding must be a list of floats or ints")
        else:
            embedding_list = cast(types.Embeddings, record_set["embeddings"])
    else:
        raise InvalidArgument(
            "embeddings must be a list of lists, a list of numbers, or None"
        )

    return {
        "ids": wrap(record_set["ids"]),
        "documents": wrap(record_set["documents"])
        if record_set["documents"] is not None
        else None,
        "metadatas": wrap(record_set["metadatas"])
        if record_set["metadatas"] is not None
        else None,
        "embeddings": embedding_list,
    }


def count(collection: Collection, record_set: RecordSet) -> None:
    """The given collection count is equal to the number of embeddings"""
    count = collection.count()
    normalized_record_set = wrap_all(record_set)
    assert count == len(normalized_record_set["ids"])
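
# Illustrative sketch (not part of the original invariants): wrap/wrap_all
# normalize the two shapes a RecordSet may take. Assuming a hypothetical
# single-record input, e.g.
#
#   record_set = {
#       "ids": "id1",
#       "embeddings": [0.0, 1.0],
#       "documents": "doc1",
#       "metadatas": {"key": "value"},
#   }
#
# wrap_all(record_set) would promote scalars to singleton lists:
#
#   {
#       "ids": ["id1"],
#       "embeddings": [[0.0, 1.0]],
#       "documents": ["doc1"],
#       "metadatas": [{"key": "value"}],
#   }
#
# i.e. a flat list of numbers is treated as one embedding, not many.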

def _field_matches(
    collection: Collection,
    normalized_record_set: NormalizedRecordSet,
    field_name: Union[Literal["documents"], Literal["metadatas"]],
) -> None:
    """
    The actual embedding field is equal to the expected field

    field_name: one of [documents, metadatas]
    """
    result = collection.get(ids=normalized_record_set["ids"], include=[field_name])
    # Here we sort by the ids to match the input order; without this,
    # test_out_of_order_ids in test_add.py fails.
    embedding_id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])}
    actual_field = result[field_name]
    # This assert should never fire: when metadatas/documents are included, the
    # result is [None, None, ...] if there is no metadata, never just None.
    assert actual_field is not None
    sorted_field = sorted(
        enumerate(actual_field),
        key=lambda index_and_field_value: embedding_id_to_index[
            result["ids"][index_and_field_value[0]]
        ],
    )
    field_values = [field_value for _, field_value in sorted_field]

    expected_field = normalized_record_set[field_name]
    if expected_field is None:
        # Since a RecordSet is the user input, we need to convert the documents
        # to a List, since that's what the API returns -> None per entry
        expected_field = [None] * len(normalized_record_set["ids"])  # type: ignore
    assert field_values == expected_field


def ids_match(collection: Collection, record_set: RecordSet) -> None:
    """The actual embedding ids are equal to the expected ids"""
    normalized_record_set = wrap_all(record_set)
    actual_ids = collection.get(ids=normalized_record_set["ids"], include=[])["ids"]
    # Here we sort the ids to match the input order; without this,
    # test_out_of_order_ids in test_add.py fails.
    embedding_id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])}
    actual_ids = sorted(actual_ids, key=lambda id: embedding_id_to_index[id])
    assert actual_ids == normalized_record_set["ids"]


def metadatas_match(collection: Collection, record_set: RecordSet) -> None:
    """The actual embedding metadata is equal to the expected metadata"""
    normalized_record_set = wrap_all(record_set)
    _field_matches(collection, normalized_record_set, "metadatas")


def documents_match(collection: Collection, record_set: RecordSet) -> None:
    """The actual embedding documents are equal to the expected documents"""
    normalized_record_set = wrap_all(record_set)
    _field_matches(collection, normalized_record_set, "documents")


def no_duplicates(collection: Collection) -> None:
    ids = collection.get()["ids"]
    assert len(ids) == len(set(ids))


# These distance functions match the hnswlib spec.
# NORM_EPS prevents division by zero; its value matches hnswlib's:
# https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238
NORM_EPS = 1e-30
distance_functions: Dict[str, Callable[[npt.ArrayLike, npt.ArrayLike], float]] = {
    "l2": lambda x, y: np.linalg.norm(x - y) ** 2,  # type: ignore
    "cosine": lambda x, y: 1 - np.dot(x, y) / ((np.linalg.norm(x) + NORM_EPS) * (np.linalg.norm(y) + NORM_EPS)),  # type: ignore
    "ip": lambda x, y: 1 - np.dot(x, y),  # type: ignore
}


def _exact_distances(
    query: types.Embeddings,
    targets: types.Embeddings,
    distance_fn: Callable[[npt.ArrayLike, npt.ArrayLike], float] = distance_functions["l2"],
) -> Tuple[List[List[int]], List[List[float]]]:
    """Return the ordered indices and distances from each query to each target"""
    np_query = np.array(query)
    np_targets = np.array(targets)

    # Compute the distance between each query and each target, using the
    # distance function
    distances = np.apply_along_axis(
        lambda query: np.apply_along_axis(distance_fn, 1, np_targets, query),
        1,
        np_query,
    )
    # Sort the distances and return the indices
    return np.argsort(distances).tolist(), distances.tolist()
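
# A quick hand-check of the distance functions and _exact_distances above
# (hypothetical values, shown for illustration only):
#
#   x, y = np.array([1.0, 0.0]), np.array([0.0, 1.0])
#   distance_functions["l2"](x, y)      # 2.0  -- squared L2, as in hnswlib
#   distance_functions["cosine"](x, y)  # ~1.0 -- orthogonal vectors
#   distance_functions["ip"](x, y)      # 1.0  -- 1 minus the dot product
#
#   _exact_distances([[0.0], [1.0], [3.0]], [[0.0], [1.0], [3.0]])
#   # indices:   [[0, 1, 2], [1, 0, 2], [2, 1, 0]]  -- nearest-first per query
#   # distances: [[0.0, 1.0, 9.0], [1.0, 0.0, 4.0], [9.0, 4.0, 0.0]]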

def ann_accuracy(
    collection: Collection,
    record_set: RecordSet,
    n_results: int = 1,
    min_recall: float = 0.99,
    embedding_function: Optional[types.EmbeddingFunction] = None,
) -> None:
    """Validate that the API performs nearest-neighbor searches correctly"""
    normalized_record_set = wrap_all(record_set)

    if len(normalized_record_set["ids"]) == 0:
        return  # nothing to test here

    embeddings: Optional[types.Embeddings] = normalized_record_set["embeddings"]
    have_embeddings = embeddings is not None and len(embeddings) > 0
    if not have_embeddings:
        assert embedding_function is not None
        assert normalized_record_set["documents"] is not None
        assert isinstance(normalized_record_set["documents"], list)
        # Compute the embeddings for the documents
        embeddings = embedding_function(normalized_record_set["documents"])

    # l2 is the default distance function
    distance_function = distance_functions["l2"]
    accuracy_threshold = 1e-6
    assert collection.metadata is not None
    assert embeddings is not None
    if "hnsw:space" in collection.metadata:
        space = collection.metadata["hnsw:space"]
        # TODO: ip and cosine are numerically unstable in HNSW.
        # The higher the dimensionality, the more noise is introduced, since
        # each float element of the vector has noise added, which is then
        # included in all normalization calculations. This means that higher
        # dimensions have more noise, and thus more error.
        assert all(isinstance(e, list) for e in embeddings)
        dim = len(embeddings[0])
        accuracy_threshold = accuracy_threshold * math.pow(10, int(math.log10(dim)))

        if space == "cosine":
            distance_function = distance_functions["cosine"]

        if space == "ip":
            distance_function = distance_functions["ip"]

    # Perform exact distance computation
    indices, distances = _exact_distances(
        embeddings, embeddings, distance_fn=distance_function
    )

    query_results = collection.query(
        query_embeddings=normalized_record_set["embeddings"],
        query_texts=normalized_record_set["documents"]
        if not have_embeddings
        else None,
        n_results=n_results,
        include=["embeddings", "documents", "metadatas", "distances"],
    )

    assert query_results["distances"] is not None
    assert query_results["documents"] is not None
    assert query_results["metadatas"] is not None
    assert query_results["embeddings"] is not None

    # Dict of ids to indices
    id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])}
    missing = 0
    for i, (indices_i, distances_i) in enumerate(zip(indices, distances)):
        expected_ids = np.array(normalized_record_set["ids"])[indices_i[:n_results]]
        missing += len(set(expected_ids) - set(query_results["ids"][i]))

        # For each id in the query results, find the index in the embeddings
        # set and assert that the embeddings are the same
        for j, id in enumerate(query_results["ids"][i]):
            # An unexpected id may appear because the true nth nearest neighbor
            # didn't get returned by the ANN query
            unexpected_id = id not in expected_ids
            index = id_to_index[id]

            correct_distance = np.allclose(
                distances_i[index],
                query_results["distances"][i][j],
                atol=accuracy_threshold,
            )
            if unexpected_id:
                # If the ID is unexpected but the distance is correct, then we
                # have a duplicate in the data. In this case, we should not
                # reduce recall.
                if correct_distance:
                    missing -= 1
                else:
                    continue
            else:
                assert correct_distance

            assert np.allclose(embeddings[index], query_results["embeddings"][i][j])
            if normalized_record_set["documents"] is not None:
                assert (
                    normalized_record_set["documents"][index]
                    == query_results["documents"][i][j]
                )
            if normalized_record_set["metadatas"] is not None:
                assert (
                    normalized_record_set["metadatas"][index]
                    == query_results["metadatas"][i][j]
                )

    size = len(normalized_record_set["ids"])
    recall = (size - missing) / size

    try:
        note(
            f"recall: {recall}, missing {missing} out of {size}, "
            f"accuracy threshold {accuracy_threshold}"
        )
    except InvalidArgument:
        pass  # it's ok if we're running outside hypothesis

    assert recall >= min_recall

    # Ensure that the query results are sorted by distance
    for distance_result in query_results["distances"]:
        assert np.allclose(np.sort(distance_result), distance_result)
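
# A minimal usage sketch, assuming a hypothetical property test (the client,
# collection name, and record_set below are illustrative, not part of this
# module): add a record set to a collection, then assert the invariants hold.
#
#   collection = client.create_collection(
#       "test_collection", metadata={"hnsw:space": "l2"}
#   )
#   collection.add(**record_set)
#   count(collection, record_set)
#   ids_match(collection, record_set)
#   documents_match(collection, record_set)
#   metadatas_match(collection, record_set)
#   no_duplicates(collection)
#   ann_accuracy(collection, record_set, n_results=5)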