"""Wrapper around ChromaDB (version 3) embeddings platform.""" | |
from __future__ import annotations | |
import logging | |
import uuid | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
Callable, | |
Dict, | |
Iterable, | |
List, | |
Optional, | |
Tuple, | |
Type, | |
) | |
import numpy as np | |
from langchain.docstore.document import Document | |
from langchain.embeddings.base import Embeddings | |
from langchain.utils import xor_args | |
from langchain.vectorstores.base import VectorStore | |
from langchain.vectorstores.utils import maximal_marginal_relevance | |
if TYPE_CHECKING: | |
import chromamigdb | |
import chromamigdb.config | |
from chromamigdb.api.types import ID, OneOrMany, Where, WhereDocument | |
logger = logging.getLogger() | |
DEFAULT_K = 4 # Number of Documents to return. | |


def _results_to_docs(results: Any) -> List[Document]:
    return [doc for doc, _ in _results_to_docs_and_scores(results)]


def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    return [
        # TODO: Chroma can do batch querying,
        # we shouldn't hard code to the 1st result
        (Document(page_content=result[0], metadata=result[1] or {}), result[2])
        for result in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
    ]
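
# For reference, a minimal sketch (an assumed shape, following chromadb's query
# API) of the ``results`` mapping these helpers consume -- the outer lists hold
# one entry per query, and only the first query's results are read:
#
#     results = {
#         "documents": [["doc one text", "doc two text"]],
#         "metadatas": [[{"source": "a"}, None]],
#         "distances": [[0.12, 0.34]],
#     }
#     _results_to_docs_and_scores(results)
#     # -> [(Document(page_content="doc one text", metadata={"source": "a"}), 0.12),
#     #     (Document(page_content="doc two text", metadata={}), 0.34)]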


class ChromaMig(VectorStore):
    """Wrapper around ChromaDB embeddings platform.

    To use, you should have the ``chromamigdb`` python package installed.

    Example:
        .. code-block:: python

            from langchain.embeddings.openai import OpenAIEmbeddings

            embeddings = OpenAIEmbeddings()
            vectorstore = ChromaMig("langchain_store", embeddings)
    """

    _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"

    def __init__(
        self,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        embedding_function: Optional[Embeddings] = None,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromamigdb.config.Settings] = None,
        collection_metadata: Optional[Dict] = None,
        client: Optional[chromamigdb.Client] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
    ) -> None:
        """Initialize with Chroma client."""
        try:
            import chromamigdb
            import chromamigdb.config
        except ImportError:
            raise ValueError(
                "Could not import chromamigdb python package. "
                "Please install it with `pip install chromamigdb`."
            )

        if client is not None:
            self._client_settings = client_settings
            self._client = client
            self._persist_directory = persist_directory
        else:
            if client_settings:
                # If client_settings is provided with persist_directory specified,
                # then it is "in-memory and persisting to disk" mode.
                client_settings.persist_directory = (
                    persist_directory or client_settings.persist_directory
                )
                if client_settings.persist_directory is not None:
                    # Maintain backwards compatibility with chromamigdb < 0.4.0
                    major, minor, _ = chromamigdb.__version__.split(".")
                    if int(major) == 0 and int(minor) < 4:
                        client_settings.chroma_db_impl = "duckdb+parquet"
                _client_settings = client_settings
            elif persist_directory:
                # Maintain backwards compatibility with chromamigdb < 0.4.0
                major, minor, _ = chromamigdb.__version__.split(".")
                if int(major) == 0 and int(minor) < 4:
                    _client_settings = chromamigdb.config.Settings(
                        chroma_db_impl="duckdb+parquet",
                    )
                else:
                    _client_settings = chromamigdb.config.Settings(is_persistent=True)
                _client_settings.persist_directory = persist_directory
            else:
                _client_settings = chromamigdb.config.Settings()
            self._client_settings = _client_settings
            self._client = chromamigdb.Client(_client_settings)
            self._persist_directory = (
                _client_settings.persist_directory or persist_directory
            )

        self._embedding_function = embedding_function
        self._collection = self._client.get_or_create_collection(
            name=collection_name,
            embedding_function=self._embedding_function.embed_documents
            if self._embedding_function is not None
            else None,
            metadata=collection_metadata,
        )
        self.override_relevance_score_fn = relevance_score_fn
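
    # A minimal usage sketch of the three construction modes handled above
    # (values are illustrative, not part of this module):
    #
    #     store = ChromaMig("my_docs", embeddings)              # ephemeral, in-memory
    #     store = ChromaMig("my_docs", embeddings,
    #                       persist_directory="./chroma_db")    # persisted to disk
    #     client = chromamigdb.Client(chromamigdb.config.Settings())
    #     store = ChromaMig("my_docs", embeddings, client=client)  # caller-managed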

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self._embedding_function

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the chroma collection."""
        try:
            import chromamigdb  # noqa: F401
        except ImportError:
            raise ValueError(
                "Could not import chromamigdb python package. "
                "Please install it with `pip install chromamigdb`."
            )
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            **kwargs,
        )
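
    # Exactly one of ``query_texts`` / ``query_embeddings`` should be supplied
    # (enforced by ``xor_args`` above); e.g.:
    #
    #     results = self.__query_collection(query_texts=["hello"], n_results=4)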

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts (Iterable[str]): Texts to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        # TODO: Handle the case where the user doesn't provide ids on the Collection
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in texts]
        embeddings = None
        texts = list(texts)
        if self._embedding_function is not None:
            embeddings = self._embedding_function.embed_documents(texts)
        if metadatas:
            # fill metadatas with empty dicts if somebody
            # did not specify metadata for all texts
            length_diff = len(texts) - len(metadatas)
            if length_diff:
                metadatas = metadatas + [{}] * length_diff
            empty_ids = []
            non_empty_ids = []
            for idx, m in enumerate(metadatas):
                if m:
                    non_empty_ids.append(idx)
                else:
                    empty_ids.append(idx)
            if non_empty_ids:
                metadatas = [metadatas[idx] for idx in non_empty_ids]
                texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
                embeddings_with_metadatas = (
                    [embeddings[idx] for idx in non_empty_ids] if embeddings else None
                )
                ids_with_metadata = [ids[idx] for idx in non_empty_ids]
                self._collection.upsert(
                    metadatas=metadatas,
                    embeddings=embeddings_with_metadatas,
                    documents=texts_with_metadatas,
                    ids=ids_with_metadata,
                )
            if empty_ids:
                texts_without_metadatas = [texts[j] for j in empty_ids]
                embeddings_without_metadatas = (
                    [embeddings[j] for j in empty_ids] if embeddings else None
                )
                ids_without_metadatas = [ids[j] for j in empty_ids]
                self._collection.upsert(
                    embeddings=embeddings_without_metadatas,
                    documents=texts_without_metadatas,
                    ids=ids_without_metadatas,
                )
        else:
            self._collection.upsert(
                embeddings=embeddings,
                documents=texts,
                ids=ids,
            )
        return ids
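
    # Usage sketch (illustrative values). Texts without metadata are upserted
    # separately from texts with metadata, as above:
    #
    #     ids = store.add_texts(
    #         ["alpha", "beta"],
    #         metadatas=[{"source": "a"}],  # shorter than texts; padded with {}
    #     )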

    def similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run similarity search with Chroma.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Document]: List of documents most similar to the query text.
        """
        docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
        return [doc for doc, _ in docs_and_scores]
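
    # Usage sketch (illustrative values):
    #
    #     docs = store.similarity_search("query text", k=2, filter={"source": "a"})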

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query vector.
        """
        results = self.__query_collection(
            query_embeddings=embedding, n_results=k, where=filter
        )
        return _results_to_docs(results)

    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Return docs most similar to embedding vector and similarity score.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to
                the query text and cosine distance in float for each.
                Lower score represents more similarity.
        """
        results = self.__query_collection(
            query_embeddings=embedding, n_results=k, where=filter
        )
        return _results_to_docs_and_scores(results)

    def similarity_search_with_score(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Chroma with distance.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to
                the query text and cosine distance in float for each.
                Lower score represents more similarity.
        """
        if self._embedding_function is None:
            results = self.__query_collection(
                query_texts=[query], n_results=k, where=filter
            )
        else:
            query_embedding = self._embedding_function.embed_query(query)
            results = self.__query_collection(
                query_embeddings=[query_embedding], n_results=k, where=filter
            )

        return _results_to_docs_and_scores(results)
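
    # The scores returned here are raw Chroma distances (lower means more
    # similar), not normalized relevance scores; e.g.:
    #
    #     docs_and_scores = store.similarity_search_with_score("query text", k=2)
    #     best_doc, best_distance = docs_and_scores[0]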

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The 'correct' relevance function
        may differ depending on a few things, including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
        - embedding dimensionality
        - etc.
        """
        if self.override_relevance_score_fn:
            return self.override_relevance_score_fn

        distance = "l2"
        distance_key = "hnsw:space"
        metadata = self._collection.metadata

        if metadata and distance_key in metadata:
            distance = metadata[distance_key]

        if distance == "cosine":
            return self._cosine_relevance_score_fn
        elif distance == "l2":
            return self._euclidean_relevance_score_fn
        elif distance == "ip":
            return self._max_inner_product_relevance_score_fn
        else:
            raise ValueError(
                "No supported normalization function"
                f" for distance metric of type: {distance}. "
                "Consider providing relevance_score_fn to Chroma constructor."
            )
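
    # For example, a collection created with collection_metadata={"hnsw:space":
    # "cosine"} selects the inherited cosine helper, which maps a distance d to
    # a relevance of 1 - d (so a distance of 0.12 becomes a relevance of 0.88).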

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=fetch_k,
            where=filter,
            include=["metadatas", "documents", "distances", "embeddings"],
        )
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            results["embeddings"][0],
            k=k,
            lambda_mult=lambda_mult,
        )

        candidates = _results_to_docs(results)

        selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
        return selected_results

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        if self._embedding_function is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function on creation."
            )

        embedding = self._embedding_function.embed_query(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding, k, fetch_k, lambda_mult=lambda_mult, filter=filter
        )
        return docs
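
    # Usage sketch: fetch_k candidates are retrieved, then k are kept by MMR
    # (lambda_mult=1.0 reduces to plain similarity ranking):
    #
    #     docs = store.max_marginal_relevance_search("query", k=4, fetch_k=20,
    #                                                lambda_mult=0.5)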

    def delete_collection(self) -> None:
        """Delete the collection."""
        self._client.delete_collection(self._collection.name)

    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Gets the collection.

        Args:
            ids: The ids of the embeddings to get. Optional.
            where: A Where type dict used to filter results by.
                E.g. `{"color" : "red", "price": 4.20}`. Optional.
            limit: The number of documents to return. Optional.
            offset: The offset to start returning results from.
                Useful for paging results with limit. Optional.
            where_document: A WhereDocument type dict used to filter by the documents.
                E.g. `{"$contains": "hello"}`. Optional.
            include: A list of what to include in the results.
                Can contain `"embeddings"`, `"metadatas"`, `"documents"`.
                Ids are always included.
                Defaults to `["metadatas", "documents"]`. Optional.
        """
        kwargs = {
            "ids": ids,
            "where": where,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
        }

        if include is not None:
            kwargs["include"] = include

        return self._collection.get(**kwargs)
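
    # Usage sketch (illustrative values); the result is a dict keyed by
    # "ids", "documents", "metadatas", etc.:
    #
    #     batch = store.get(where={"source": "a"}, limit=10)
    #     for doc_id, text in zip(batch["ids"], batch["documents"]):
    #         ...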

    def persist(self) -> None:
        """Persist the collection.

        This can be used to explicitly persist the data to disk.
        It will also be called automatically when the object is destroyed.
        """
        if self._persist_directory is None:
            raise ValueError(
                "You must specify a persist_directory on "
                "creation to persist the collection."
            )
        import chromamigdb

        # Maintain backwards compatibility with chromamigdb < 0.4.0, where
        # persistence had to be triggered explicitly; from 0.4.0 on,
        # persistent clients write to disk automatically.
        major, minor, _ = chromamigdb.__version__.split(".")
        if int(major) == 0 and int(minor) < 4:
            self._client.persist()

    def update_document(self, document_id: str, document: Document) -> None:
        """Update a document in the collection.

        Args:
            document_id (str): ID of the document to update.
            document (Document): Document to update.
        """
        text = document.page_content
        metadata = document.metadata
        if self._embedding_function is None:
            raise ValueError(
                "For update, you must specify an embedding function on creation."
            )
        embeddings = self._embedding_function.embed_documents([text])

        self._collection.update(
            ids=[document_id],
            embeddings=embeddings,
            documents=[text],
            metadatas=[metadata],
        )
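
    # Usage sketch (illustrative values):
    #
    #     store.update_document(
    #         doc_id, Document(page_content="new text", metadata={"source": "a"})
    #     )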

    @classmethod
    def from_texts(
        cls: Type[ChromaMig],
        texts: List[str],
        embedding: Optional[Embeddings] = None,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromamigdb.config.Settings] = None,
        client: Optional[chromamigdb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> ChromaMig:
        """Create a Chroma vectorstore from a list of raw texts.

        If a persist_directory is specified, the collection will be persisted there.
        Otherwise, the data will be ephemeral in-memory.

        Args:
            texts (List[str]): List of texts to add to the collection.
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            client_settings (Optional[chromamigdb.config.Settings]): Chroma client settings
            collection_metadata (Optional[Dict]): Collection configurations.
                Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        chroma_collection = cls(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )
        chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return chroma_collection
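
    # Usage sketch (illustrative values):
    #
    #     store = ChromaMig.from_texts(
    #         ["alpha", "beta"],
    #         embedding=embeddings,
    #         persist_directory="./chroma_db",
    #     )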

    @classmethod
    def from_documents(
        cls: Type[ChromaMig],
        documents: List[Document],
        embedding: Optional[Embeddings] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromamigdb.config.Settings] = None,
        client: Optional[chromamigdb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> ChromaMig:
        """Create a Chroma vectorstore from a list of documents.

        If a persist_directory is specified, the collection will be persisted there.
        Otherwise, the data will be ephemeral in-memory.

        Args:
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            documents (List[Document]): List of documents to add to the vectorstore.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            client_settings (Optional[chromamigdb.config.Settings]): Chroma client settings
            collection_metadata (Optional[Dict]): Collection configurations.
                Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts=texts,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            collection_name=collection_name,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )
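
    # Usage sketch (illustrative values), delegating to from_texts above:
    #
    #     store = ChromaMig.from_documents(docs, embedding=embeddings,
    #                                      persist_directory="./chroma_db")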

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
        """
        self._collection.delete(ids=ids)