import fastapi
import json, time
import uvicorn
from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# NOTE(review): this binds the name `Response` to the FastAPI application
# class, not to an HTTP response type. Kept so existing references keep
# resolving; new code should use `StarletteResponse` below.
from fastapi import FastAPI as Response
from sse_starlette.sse import EventSourceResponse
from starlette.responses import StreamingResponse
from starlette.responses import Response as StarletteResponse
from starlette.requests import Request
import chromadb
from chromadb.config import Settings, System
from pydantic import BaseModel
from typing import List, Dict, Any, Generator, Optional, cast, Callable
from chromadb.api.types import (
    Documents,
    Embeddings,
    EmbeddingFunction,
    IDs,
    Include,
    Metadatas,
    Where,
    WhereDocument,
    GetResult,
    QueryResult,
    CollectionMetadata,
)
from chromadb.errors import (
    ChromaError,
    InvalidUUIDError,
    InvalidDimensionException,
)
from chromadb.server.fastapi.types import (
    AddEmbedding,
    DeleteEmbedding,
    GetEmbedding,
    QueryEmbedding,
    RawSql,  # Results,
    CreateCollection,
    UpdateCollection,
    UpdateEmbedding,
)
from chromadb.api import API
from chromadb.config import System
import chromadb.utils.embedding_functions as ef
import pandas as pd
import requests
import json, os
from typing import Sequence
from chromadb.api.models.Collection import Collection
import chromadb.errors as errors
from uuid import UUID
from chromadb.telemetry import Telemetry
from overrides import override
import dropbox_handler as dbh


async def catch_exceptions_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> StarletteResponse:
    """Run the downstream handler and turn uncaught exceptions into JSON.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response when no exception escapes. For a
        ``ChromaError`` a JSON body with ``error`` (``e.name()``) and
        ``message`` (``e.message()``) using ``e.code()`` as the status code.
        For any other exception a JSON body with ``error`` set to
        ``repr(e)`` and status code 500.
    """
    try:
        return await call_next(request)
    except ChromaError as e:
        return JSONResponse(
            content={"error": e.name(), "message": e.message()},
            status_code=e.code(),
        )
    except Exception as e:
        # Top-level boundary: nothing is re-raised, the client gets a 500.
        return JSONResponse(content={"error": repr(e)}, status_code=500)


def _uuid(uuid_str: str) -> UUID:
    """Parse *uuid_str* into a ``UUID``.

    Raises:
        InvalidUUIDError: If ``UUID(uuid_str)`` raises ``ValueError``; the
            original error is attached as the cause.
    """
    try:
        return UUID(uuid_str)
    except ValueError as err:
        raise InvalidUUIDError(f"Could not parse {uuid_str} as a UUID") from err


# Import-time side effect: runs before the application object is created.
# Presumably pulls a previously backed-up index back into place -- confirm
# against dropbox_handler.restoreFolder.
dbh.restoreFolder("/index/chroma")

app = fastapi.FastAPI(title="ChromaDB")
# Register the exception-to-JSON translator as an HTTP middleware.
app.middleware("http")(catch_exceptions_middleware)
# CORS: every origin, method and header is allowed, with credentials enabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Common URL prefix for the routes registered below (except "/info").
api_base = "/api/v1"
# Embedding function handed to create_collection / get_collection.
embedding_function = ef.DefaultEmbeddingFunction()
# NOTE(review): the dbh.restoreFolder / dbh.backupFolder calls in this file
# use "/index/chroma" while persist_directory is "./index/chroma" -- confirm
# that dropbox_handler maps one onto the other.
bkend = chromadb.Client(Settings(
    chroma_db_impl="duckdb+parquet",
    persist_directory="./index/chroma",  # Optional, defaults to .chromadb/ in the current directory
))

# NOTE: route handlers and the request model below are documented with `#`
# comments rather than docstrings, because FastAPI/pydantic publish
# docstrings in the generated OpenAPI schema.


# Request body for the /walk route; `dir` is the directory to walk.
class PathRequest(BaseModel):
    dir: str = "/"


# Liveness probe: returns the current time.time_ns() value.
@app.get(api_base + "")
def heartbeat():
    print("Received heartbeat request")
    return {"nanosecond heartbeat": time.time_ns()}


# Debug route: prints the request and its headers, then answers with the
# heartbeat payload and sets an HTTP-only, secure cookie.
@app.get("/info")
def read_info(request: Request):
    print(request)
    print(request.headers)
    response = JSONResponse(content={"nanosecond heartbeat": time.time_ns()})
    response.set_cookie(key="my_cookie", value="12345", httponly=True, secure=True)
    return response


# Calls dbh.restoreFolder and then resets the backend.
# NOTE(review): the restore happens *before* bkend.reset() -- confirm that
# this order is intended.
@app.post(api_base + "/reset")
def reset():
    print("Received reset request")
    dbh.restoreFolder("/index/chroma")
    return bkend.reset()


# Returns whatever bkend.get_version() reports.
@app.get(api_base + "/version")
def version():
    print("Received version request")
    return bkend.get_version()


# Persists the backend, then calls dbh.backupFolder; returns the result of
# bkend.persist().
@app.post(api_base + "/persist")
def persist():
    print("Received persist request")
    persisted = bkend.persist()
    dbh.backupFolder("/index/chroma")
    return persisted


# Forwards the SQL text from the request body to bkend.raw_sql.
@app.post(api_base + "/raw_sql")
def raw_sql(raw_sql: RawSql):
    print("Received raw_sql request")
    return bkend.raw_sql(raw_sql.raw_sql)


# Collects the names of all directories found while walking `path.dir`
# top-down and returns them as {"dirs": [...]}.
# SECURITY NOTE(review): the directory comes straight from the request body,
# so any client can list directory names under an arbitrary path. Consider
# restricting this to an allow-listed root.
@app.post(api_base + "/walk")
def walk(path: PathRequest):
    print("Received walk request")
    dirs = []
    try:
        for _root, subdirs, _files in os.walk(path.dir, topdown=True):
            dirs.extend(subdirs)
    except Exception as exc:
        # Best effort: report the actual error (not the Exception class) and
        # still return whatever was collected before the failure.
        print("got exception", exc)
    return JSONResponse(content={"dirs": dirs})


# Alias route that delegates to heartbeat().
@app.get(api_base + "/heartbeat")
def heartbeat1():
    print("Received heartbeat1 request")
    return heartbeat()


# Prints the request cookies and returns bkend.list_collections().
@app.get(api_base + "/collections")
def list_collections(request: Request):
    print("Received list_collections request")
    print(request.cookies)
    return bkend.list_collections()
# NOTE: the handlers below are documented with `#` comments rather than
# docstrings, because FastAPI publishes handler docstrings in the generated
# OpenAPI schema.


# Forwards name, metadata and get_or_create from the request body to
# bkend.create_collection, together with the module-level embedding function.
@app.post(api_base + "/collections")
def create_collection(request: Request, collection: CreateCollection):
    print("Received request to create_collection")
    print(request.cookies)
    options = {
        "name": collection.name,
        "metadata": collection.metadata,
        "embedding_function": embedding_function,
        "get_or_create": collection.get_or_create,
    }
    return bkend.create_collection(**options)


# Looks a collection up by name via bkend.get_collection.
@app.get(api_base + "/collections/{collection_name}")
def get_collection(request: Request, collection_name: str):
    print("Received get_collection request")
    print(request.cookies)
    found = bkend.get_collection(
        collection_name,
        embedding_function=embedding_function,
    )
    return found


# Passes the request body's records to bkend._add for the given collection.
# An InvalidDimensionException is re-raised as an HTTP 500 carrying the
# error text as detail.
@app.post(api_base + "/collections/{collection_id}/add")
def add(collection_id: str, add: AddEmbedding) -> None:
    print("Received add request")
    try:
        outcome = bkend._add(
            collection_id=_uuid(collection_id),
            embeddings=add.embeddings,
            metadatas=add.metadatas,
            documents=add.documents,
            ids=add.ids,
            increment_index=add.increment_index,
        )
    except InvalidDimensionException as e:
        raise HTTPException(status_code=500, detail=str(e))
    print(outcome)
    return outcome


# Passes the request body's ids, embeddings, documents and metadatas to
# bkend._update for the given collection.
@app.post(api_base + "/collections/{collection_id}/update")
def update(collection_id: str, update: UpdateEmbedding) -> None:
    print("Received update request")
    target = _uuid(collection_id)
    return bkend._update(
        collection_id=target,
        ids=update.ids,
        embeddings=update.embeddings,
        documents=update.documents,
        metadatas=update.metadatas,
    )


# Passes the request body's records to bkend._upsert for the given collection.
@app.post(api_base + "/collections/{collection_id}/upsert")
def upsert(collection_id: str, upsert: AddEmbedding):
    print("Received upsert request")
    target = _uuid(collection_id)
    return bkend._upsert(
        collection_id=target,
        ids=upsert.ids,
        embeddings=upsert.embeddings,
        metadatas=upsert.metadatas,
        documents=upsert.documents,
        increment_index=upsert.increment_index,
    )


# Forwards the request body's ids, filters, sort, paging and include
# settings to bkend._get for the given collection.
@app.post(api_base + "/collections/{collection_id}/get")
def get(collection_id: str, get: GetEmbedding) -> GetResult:
    print("Received get request")
    target = _uuid(collection_id)
    return bkend._get(
        collection_id=target,
        ids=get.ids,
        where=get.where,
        where_document=get.where_document,
        sort=get.sort,
        limit=get.limit,
        offset=get.offset,
        include=get.include,
    )


# Forwards the request body's ids and filters to bkend._delete for the
# given collection.
@app.post(api_base + "/collections/{collection_id}/delete")
def delete(collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
    print("Received delete request")
    target = _uuid(collection_id)
    return bkend._delete(
        collection_id=target,
        ids=delete.ids,
        where=delete.where,
        where_document=delete.where_document,
    )


# Returns bkend._count for the given collection.
@app.get(api_base + "/collections/{collection_id}/count")
def count(collection_id: str) -> int:
    print("Received count request")
    target = _uuid(collection_id)
    return bkend._count(target)


# Forwards the request body's query embeddings, filters, n_results and
# include settings to bkend._query for the given collection.
@app.post(api_base + "/collections/{collection_id}/query")
def get_nearest_neighbors(collection_id: str, query: QueryEmbedding) -> QueryResult:
    print("Received get_nearest_neighbors request")
    target = _uuid(collection_id)
    return bkend._query(
        collection_id=target,
        query_embeddings=query.query_embeddings,
        n_results=query.n_results,
        where=query.where,
        where_document=query.where_document,
        include=query.include,
    )


# Calls bkend.create_index with the collection name from the URL.
@app.post(api_base + "/collections/{collection_name}/create_index")
def create_index(collection_name: str) -> bool:
    print("Received create_index request")
    built = bkend.create_index(collection_name)
    return built


# Forwards the new name and metadata from the request body to bkend._modify.
@app.put(api_base + "/collections/{collection_id}")
def modify(collection_id: str, collection: UpdateCollection) -> None:
    print("Received modify-collection request")
    target = _uuid(collection_id)
    return bkend._modify(
        id=target,
        new_name=collection.new_name,
        new_metadata=collection.new_metadata,
    )


# Calls bkend.delete_collection with the collection name from the URL.
@app.delete(api_base + "/collections/{collection_name}")
def delete_collection(collection_name: str) -> None:
    print("Received delete_collection request")
    return bkend.delete_collection(collection_name)


# Script entry point: serve the app on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)