Spaces:
Runtime error
Runtime error
from typing import Optional, cast | |
from chromadb.api import API | |
from chromadb.config import System | |
from chromadb.api.types import ( | |
Documents, | |
Embeddings, | |
EmbeddingFunction, | |
IDs, | |
Include, | |
Metadatas, | |
Where, | |
WhereDocument, | |
GetResult, | |
QueryResult, | |
CollectionMetadata, | |
) | |
import chromadb.utils.embedding_functions as ef | |
import pandas as pd | |
import requests | |
import json | |
from typing import Sequence | |
from chromadb.api.models.Collection import Collection | |
import chromadb.errors as errors | |
from uuid import UUID | |
from chromadb.telemetry import Telemetry | |
from overrides import override | |
class FastAPI(API):
    """``API`` implementation that forwards every call to a remote Chroma
    server through its HTTP interface rooted at ``/api/v1``.

    All non-OK responses are turned into exceptions by ``raise_chroma_error``.
    """

    def __init__(self, system: System):
        super().__init__(system)
        url_prefix = "https" if system.settings.chroma_server_ssl_enabled else "http"
        # Both settings are needed to build the base URL; ``require`` is
        # presumably what raises when one is missing — TODO confirm.
        system.settings.require("chroma_server_host")
        system.settings.require("chroma_server_http_port")
        self._api_url = f"{url_prefix}://{system.settings.chroma_server_host}:{system.settings.chroma_server_http_port}/api/v1"
        self._telemetry_client = self.require(Telemetry)

    def heartbeat(self) -> int:
        """Return the server's current time in nanoseconds.

        Serves as a liveness check: a non-OK response raises.
        """
        resp = requests.get(self._api_url)
        raise_chroma_error(resp)
        return int(resp.json()["nanosecond heartbeat"])

    def list_collections(self) -> Sequence[Collection]:
        """Return a ``Collection`` handle for every collection on the server."""
        resp = requests.get(self._api_url + "/collections")
        raise_chroma_error(resp)
        # Each JSON object's keys are passed straight through as Collection kwargs.
        return [Collection(self, **json_collection) for json_collection in resp.json()]

    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        get_or_create: bool = False,
    ) -> Collection:
        """Create a collection on the server and return a handle to it.

        Args:
            name: Collection name.
            metadata: Optional collection metadata, sent as-is.
            embedding_function: Attached to the returned client-side handle
                only; it is not sent to the server.
            get_or_create: Forwarded to the server as the ``get_or_create`` flag.

        Returns:
            A ``Collection`` built from the id/name/metadata the server echoes back.
        """
        resp = requests.post(
            self._api_url + "/collections",
            data=json.dumps(
                {"name": name, "metadata": metadata, "get_or_create": get_or_create}
            ),
        )
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Collection(
            client=self,
            id=resp_json["id"],
            name=resp_json["name"],
            embedding_function=embedding_function,
            metadata=resp_json["metadata"],
        )

    def get_collection(
        self,
        name: str,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Fetch an existing collection by name and return a handle to it.

        ``embedding_function`` is attached to the returned handle only.
        """
        resp = requests.get(self._api_url + "/collections/" + name)
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Collection(
            client=self,
            name=resp_json["name"],
            id=resp_json["id"],
            embedding_function=embedding_function,
            metadata=resp_json["metadata"],
        )

    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Return the named collection, creating it first if it does not exist.

        Delegates to ``create_collection`` with ``get_or_create=True``.
        """
        return self.create_collection(
            name, metadata, embedding_function, get_or_create=True
        )

    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Rename a collection and/or replace its metadata.

        ``None`` values are sent as JSON ``null``; how the server interprets
        them is decided server-side.
        """
        resp = requests.put(
            self._api_url + "/collections/" + str(id),
            data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
        )
        raise_chroma_error(resp)

    def delete_collection(self, name: str) -> None:
        """Delete the collection with the given name."""
        resp = requests.delete(self._api_url + "/collections/" + name)
        raise_chroma_error(resp)

    def _count(self, collection_id: UUID) -> int:
        """Return the number of records in the given collection."""
        resp = requests.get(
            self._api_url + "/collections/" + str(collection_id) + "/count"
        )
        raise_chroma_error(resp)
        return cast(int, resp.json())

    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """Return up to ``n`` records, including embeddings, documents and metadatas."""
        return self._get(
            collection_id,
            limit=n,
            include=["embeddings", "documents", "metadatas"],
        )

    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents"],
    ) -> GetResult:
        """Fetch records from a collection.

        Args:
            collection_id: Target collection.
            ids: Id filter; ``None`` is sent as JSON ``null``.
            where: Metadata filter, forwarded as-is.
            sort: Sort key, forwarded as-is.
            limit: Maximum number of records requested.
            offset: Offset into the result set.
            page: 1-based page number. Used only together with ``page_size``;
                when both are truthy they override ``offset`` and ``limit``.
            page_size: Records per page (see ``page``).
            where_document: Document filter, forwarded as-is.
            include: Which fields the server should return.

        Returns:
            A ``GetResult``; fields missing from the server response are ``None``.
        """
        # NOTE: the {} / [] defaults are shared across calls, but they are only
        # serialized below and never mutated, so sharing them is harmless.
        if page and page_size:
            # Client-side paging: page/page_size are not sent to the server.
            offset = (page - 1) * page_size
            limit = page_size
        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/get",
            data=json.dumps(
                {
                    "ids": ids,
                    "where": where,
                    "sort": sort,
                    "limit": limit,
                    "offset": offset,
                    "where_document": where_document,
                    "include": include,
                }
            ),
        )
        raise_chroma_error(resp)
        body = resp.json()
        return GetResult(
            ids=body["ids"],
            embeddings=body.get("embeddings", None),
            metadatas=body.get("metadatas", None),
            documents=body.get("documents", None),
        )

    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
    ) -> IDs:
        """Delete records matching the given ids/filters.

        Returns:
            The JSON body returned by the server, cast to ``IDs``.
        """
        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/delete",
            data=json.dumps(
                {"where": where, "ids": ids, "where_document": where_document}
            ),
        )
        raise_chroma_error(resp)
        return cast(IDs, resp.json())

    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """
        Adds a batch of embeddings to the database
        - pass in column oriented data lists
        - by default, the index is progressively built up as you add more data. If for ingestion performance reasons you want to disable this, set increment_index to False
        - and then manually create the index yourself with collection.create_index()

        Returns True on success; a non-OK response raises.
        """
        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/add",
            data=json.dumps(
                {
                    "ids": ids,
                    "embeddings": embeddings,
                    "metadatas": metadatas,
                    "documents": documents,
                    "increment_index": increment_index,
                }
            ),
        )
        raise_chroma_error(resp)
        return True

    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """
        Updates a batch of existing records in the database
        - pass in column oriented data lists

        Returns True on success; a non-OK response raises.
        """
        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/update",
            data=json.dumps(
                {
                    "ids": ids,
                    "embeddings": embeddings,
                    "metadatas": metadatas,
                    "documents": documents,
                }
            ),
        )
        # Same error translation as every other endpoint, so server-reported
        # Chroma errors surface as their typed exception, not a bare HTTPError.
        raise_chroma_error(resp)
        return True

    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """
        Upserts (updates existing or inserts new) a batch of records
        - pass in column oriented data lists

        Returns True on success; a non-OK response raises.
        """
        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/upsert",
            data=json.dumps(
                {
                    "ids": ids,
                    "embeddings": embeddings,
                    "metadatas": metadatas,
                    "documents": documents,
                    "increment_index": increment_index,
                }
            ),
        )
        # Same error translation as every other endpoint.
        raise_chroma_error(resp)
        return True

    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents", "distances"],
    ) -> QueryResult:
        """Find the nearest neighbours of a batch of query embeddings.

        Args:
            collection_id: Target collection.
            query_embeddings: The query vectors.
            n_results: Number of neighbours requested.
            where: Metadata filter, forwarded as-is.
            where_document: Document filter, forwarded as-is.
            include: Which fields the server should return.

        Returns:
            A ``QueryResult``; fields missing from the server response are ``None``.
        """
        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/query",
            data=json.dumps(
                {
                    "query_embeddings": query_embeddings,
                    "n_results": n_results,
                    "where": where,
                    "where_document": where_document,
                    "include": include,
                }
            ),
        )
        raise_chroma_error(resp)
        body = resp.json()
        return QueryResult(
            ids=body["ids"],
            distances=body.get("distances", None),
            embeddings=body.get("embeddings", None),
            metadatas=body.get("metadatas", None),
            documents=body.get("documents", None),
        )

    def reset(self) -> None:
        """Ask the server to reset the database."""
        resp = requests.post(self._api_url + "/reset")
        raise_chroma_error(resp)

    def persist(self) -> bool:
        """Ask the server to persist the database; returns the server's reply."""
        resp = requests.post(self._api_url + "/persist")
        raise_chroma_error(resp)
        return cast(bool, resp.json())

    def raw_sql(self, sql: str) -> pd.DataFrame:
        """Run a raw SQL query on the server and return the result as a DataFrame.

        The SQL text is sent verbatim; no escaping happens on the client.
        """
        resp = requests.post(
            self._api_url + "/raw_sql", data=json.dumps({"raw_sql": sql})
        )
        raise_chroma_error(resp)
        return pd.DataFrame.from_dict(resp.json())

    def create_index(self, collection_name: str) -> bool:
        """Ask the server to build the index for the named collection."""
        resp = requests.post(
            self._api_url + "/collections/" + collection_name + "/create_index"
        )
        raise_chroma_error(resp)
        return cast(bool, resp.json())

    def get_version(self) -> str:
        """Return the version string reported by the server."""
        resp = requests.get(self._api_url + "/version")
        raise_chroma_error(resp)
        return cast(str, resp.json())
def raise_chroma_error(resp: requests.Response) -> None:
    """Raise if ``resp`` is not OK, preferring a typed Chroma error.

    If the JSON body carries an ``"error"`` name registered in
    ``errors.error_types``, that error class is raised with the body's
    ``"message"``. Otherwise a generic ``Exception`` carrying the raw
    response text is raised, chained to the underlying ``HTTPError``.

    Args:
        resp: The HTTP response to check.

    Raises:
        A registered Chroma error type, or ``Exception``.
    """
    if resp.ok:
        return

    chroma_error = None
    try:
        body = resp.json()
        if "error" in body:
            if body["error"] in errors.error_types:
                chroma_error = errors.error_types[body["error"]](body["message"])
    except Exception:
        # Deliberate best effort: a non-JSON body, an unexpected JSON shape or a
        # missing "message" key just means we fall back to the generic error.
        # Catch Exception (not BaseException) so Ctrl-C / SystemExit still work.
        pass

    # Exception objects are always truthy; compare against None explicitly.
    if chroma_error is not None:
        raise chroma_error

    try:
        resp.raise_for_status()
    except requests.HTTPError as err:
        raise Exception(resp.text) from err