|
import fastapi |
|
import json,time |
|
import uvicorn |
|
from fastapi import HTTPException , status |
|
from fastapi.responses import JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi import FastAPI as Response |
|
from sse_starlette.sse import EventSourceResponse |
|
from starlette.responses import StreamingResponse |
|
from starlette.requests import Request |
|
import chromadb |
|
from chromadb.config import Settings, System |
|
from pydantic import BaseModel |
|
from typing import List, Dict, Any, Generator, Optional, cast, Callable |
|
from chromadb.api.types import ( |
|
Documents, |
|
Embeddings, |
|
EmbeddingFunction, |
|
IDs, |
|
Include, |
|
Metadatas, |
|
Where, |
|
WhereDocument, |
|
GetResult, |
|
QueryResult, |
|
CollectionMetadata, |
|
) |
|
from chromadb.errors import ( |
|
ChromaError, |
|
InvalidUUIDError, |
|
InvalidDimensionException, |
|
) |
|
from chromadb.server.fastapi.types import ( |
|
AddEmbedding, |
|
DeleteEmbedding, |
|
GetEmbedding, |
|
QueryEmbedding, |
|
RawSql, |
|
CreateCollection, |
|
UpdateCollection, |
|
UpdateEmbedding, |
|
) |
|
from chromadb.api import API |
|
from chromadb.config import System |
|
import chromadb.utils.embedding_functions as ef |
|
import pandas as pd |
|
import requests |
|
import json,os |
|
from typing import Sequence |
|
from chromadb.api.models.Collection import Collection |
|
import chromadb.errors as errors |
|
from uuid import UUID |
|
from chromadb.telemetry import Telemetry |
|
from overrides import override |
|
import dropbox_handler as dbh |
|
|
|
async def catch_exceptions_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Any:
    """Translate exceptions escaping a request handler into JSON error responses.

    Args:
        request: The incoming Starlette request.
        call_next: The downstream handler chain supplied by the middleware stack.

    Returns:
        The downstream response on success. A ``ChromaError`` becomes
        ``{"error": name, "message": message}`` with the error's own status code.
        Any other exception becomes ``{"error": repr(exc)}`` with status 500.
    """
    import traceback

    try:
        return await call_next(request)
    except ChromaError as e:
        return JSONResponse(
            content={"error": e.name(), "message": e.message()}, status_code=e.code()
        )
    except Exception as e:
        # The client only receives the repr; print the full traceback so the
        # failure is diagnosable from the server logs.
        # NOTE(review): repr(e) may expose internal details to callers.
        traceback.print_exc()
        return JSONResponse(content={"error": repr(e)}, status_code=500)
|
|
|
|
|
|
|
def _uuid(uuid_str: str) -> UUID: |
|
try: |
|
return UUID(uuid_str) |
|
except ValueError: |
|
raise InvalidUUIDError(f"Could not parse {uuid_str} as a UUID") |
|
|
|
# Presumably restores the index folder via dropbox_handler before the Chroma
# client below is constructed, so the client starts from the restored files.
# NOTE(review): the restore path is "/index/chroma" while the client persists to
# "./index/chroma"; presumably dropbox_handler maps the remote path onto the
# local relative directory — confirm.
dbh.restoreFolder("/index/chroma")

app = fastapi.FastAPI(title="ChromaDB")
# Every HTTP request passes through the exception-to-JSON middleware.
app.middleware("http")(catch_exceptions_middleware)
# NOTE(review): wildcard origins combined with allow_credentials=True may allow
# credentialed cross-origin requests from any site; restrict allow_origins
# outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# URL prefix shared by the REST routes below.
api_base="/api/v1"
# One embedding function instance, passed to create_collection / get_collection.
embedding_function=ef.DefaultEmbeddingFunction()
# Process-wide Chroma client using the "duckdb+parquet" implementation,
# persisting under ./index/chroma.
bkend=chromadb.Client(Settings(
    chroma_db_impl="duckdb+parquet",
    persist_directory="./index/chroma"
))
|
|
|
class PathRequest(BaseModel):
    """Request body for the walk endpoint."""

    # Directory to start walking from; defaults to "/" when omitted.
    dir: str = "/"
|
|
|
@app.get(api_base)
def heartbeat():
    """Liveness probe at the API base path: report the current time in nanoseconds."""
    print("Received heartbeat request")
    now_ns = time.time_ns()
    return {"nanosecond heartbeat": now_ns}
|
|
|
@app.get("/info")
def read_info(request: Request):
    """Debug endpoint: log the incoming request and reply with a timestamp plus a cookie.

    Args:
        request: The incoming request; it and its headers are printed to stdout.

    Returns:
        JSONResponse: ``{"nanosecond heartbeat": <time.time_ns()>}``. It sets an
        HttpOnly, Secure cookie ``my_cookie=12345``.
    """
    # NOTE(review): printing all headers may write Cookie/Authorization values
    # into the server logs.
    print(request)
    print(request.headers)
    response = JSONResponse(content={"nanosecond heartbeat": int(time.time_ns())})
    response.set_cookie(key="my_cookie", value="12345", httponly=True, secure=True)
    return response
|
|
|
@app.post(api_base + "/reset")
def reset():
    """Restore the index folder via dropbox_handler, then reset the Chroma backend.

    NOTE(review): the restore runs immediately before ``bkend.reset()``, which may
    discard whatever was just restored — confirm the intended ordering. The
    endpoint also has no authentication.
    """
    print("Received reset request")
    dbh.restoreFolder("/index/chroma")
    outcome = bkend.reset()
    return outcome
|
|
|
@app.get(api_base + "/version")
def version():
    """Report the backend's version as returned by ``bkend.get_version()``."""
    print("Received version request")
    backend_version = bkend.get_version()
    return backend_version
|
|
|
@app.post(api_base + "/persist")
def persist():
    """Persist the backend, then back up the index folder via dropbox_handler.

    Returns:
        Whatever ``bkend.persist()`` returned.
    """
    print("Received persist request")
    # Persist first; the backup step runs only after persist() has returned.
    outcome = bkend.persist()
    dbh.backupFolder("/index/chroma")
    return outcome
|
|
|
@app.post(api_base + "/raw_sql")
def raw_sql(raw_sql: RawSql):
    """Forward the SQL text from the request body to ``bkend.raw_sql``.

    NOTE(review): this runs caller-supplied SQL with no authentication, and CORS
    allows all origins; restrict or remove it in production.
    """
    print("Received raw_sql request")
    statement = raw_sql.raw_sql
    return bkend.raw_sql(statement)
|
|
|
@app.post(api_base + "/walk")
def walk(path: PathRequest):
    """Collect the names of every directory found beneath ``path.dir``.

    The tree is walked recursively and top-down with ``os.walk``. Only bare
    directory names are collected, not full paths, so a name that occurs at
    several depths appears several times.

    NOTE(review): this lists any server directory the process can read, and a
    walk from "/" traverses the whole filesystem. Consider confining it to an
    allow-listed root.

    Args:
        path: Request body whose ``dir`` is the starting directory.

    Returns:
        JSONResponse: ``{"dirs": [name, ...]}``. Names gathered before an error
        are still returned (best-effort).
    """
    print("Received walk request")
    dirs = []

    def _report(err):
        # Without onerror, os.walk silently skips directories it cannot scan.
        print("got exception", err)

    try:
        for _root, subdirs, _files in os.walk(path.dir, topdown=True, onerror=_report):
            dirs.extend(subdirs)
    except Exception as exc:
        # Best-effort listing: log the actual exception and return what we have.
        print("got exception", exc)
    return JSONResponse(content={"dirs": dirs})
|
|
|
@app.get(api_base + "/heartbeat")
def heartbeat1():
    """Serve the same payload as ``heartbeat`` under the /heartbeat path."""
    print("Received heartbeat1 request")
    beat = heartbeat()
    return beat
|
|
|
@app.get(api_base + "/collections")
def list_collections(request: Request):
    """Return ``bkend.list_collections()``, printing the request cookies first."""
    print("Received list_collections request")
    cookies = request.cookies
    print(cookies)
    return bkend.list_collections()
|
|
|
@app.post(api_base + "/collections")
def create_collection(request: Request, collection: CreateCollection):
    """Create a collection, or fetch it when ``get_or_create`` is set.

    The module-level ``embedding_function`` is attached to the collection.
    Request cookies are printed to stdout.
    """
    print("Received request to create_collection")
    print(request.cookies)
    return bkend.create_collection(
        name=collection.name,
        metadata=collection.metadata,
        embedding_function=embedding_function,
        get_or_create=collection.get_or_create,
    )
|
|
|
@app.get(api_base + "/collections/{collection_name}")
def get_collection(request: Request, collection_name: str):
    """Look up a collection by name, attaching the module-level embedding function."""
    print("Received get_collection request")
    cookies = request.cookies
    print(cookies)
    return bkend.get_collection(
        collection_name,
        embedding_function=embedding_function,
    )
|
|
|
@app.post(api_base + "/collections/{collection_id}/add")
def add(collection_id: str, add: AddEmbedding):
    """Add embeddings, documents and metadata to the collection ``collection_id``.

    Args:
        collection_id: Collection UUID as text.
        add: Request body carrying embeddings, metadatas, documents, ids and
            ``increment_index``.

    Returns:
        Whatever ``bkend._add`` returns (also printed to stdout).

    Raises:
        InvalidUUIDError: ``collection_id`` is not a UUID (raised by ``_uuid``).
        HTTPException: status 500 when the backend raises
            ``InvalidDimensionException``. The original error is chained as the cause.
    """
    print("Received add request")
    target = _uuid(collection_id)
    try:
        result = bkend._add(
            collection_id=target,
            embeddings=add.embeddings,
            metadatas=add.metadatas,
            documents=add.documents,
            ids=add.ids,
            increment_index=add.increment_index,
        )
    except InvalidDimensionException as e:
        # NOTE(review): a dimension mismatch is arguably a client (4xx) error;
        # 500 is kept so existing clients observe no change.
        raise HTTPException(status_code=500, detail=str(e)) from e
    print(result)
    return result
|
|
|
@app.post(api_base + "/collections/{collection_id}/update")
def update(collection_id: str, update: UpdateEmbedding) -> None:
    """Forward an UpdateEmbedding body to ``bkend._update`` for the given collection.

    NOTE(review): the ``-> None`` annotation does not match the value returned.
    """
    print("Received update request")
    target = _uuid(collection_id)
    return bkend._update(
        ids=update.ids,
        collection_id=target,
        embeddings=update.embeddings,
        documents=update.documents,
        metadatas=update.metadatas,
    )
|
|
|
@app.post(api_base + "/collections/{collection_id}/upsert")
def upsert(collection_id: str, upsert: AddEmbedding):
    """Forward an AddEmbedding body to ``bkend._upsert`` for the given collection."""
    print("Received upsert request")
    target = _uuid(collection_id)
    return bkend._upsert(
        collection_id=target,
        embeddings=upsert.embeddings,
        metadatas=upsert.metadatas,
        documents=upsert.documents,
        ids=upsert.ids,
        increment_index=upsert.increment_index,
    )
|
|
|
@app.post(api_base + "/collections/{collection_id}/get")
def get(collection_id: str, get: GetEmbedding) -> GetResult:
    """Fetch records from a collection via ``bkend._get``.

    Filtering, sorting, paging and included fields all come from the GetEmbedding body.
    """
    print("Received get request")
    target = _uuid(collection_id)
    return bkend._get(
        collection_id=target,
        ids=get.ids,
        where=get.where,
        where_document=get.where_document,
        sort=get.sort,
        limit=get.limit,
        offset=get.offset,
        include=get.include,
    )
|
|
|
@app.post(api_base + "/collections/{collection_id}/delete")
def delete(collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
    """Delete records matching the ids/where filters via ``bkend._delete``."""
    print("Received delete request")
    target = _uuid(collection_id)
    return bkend._delete(
        where=delete.where,
        ids=delete.ids,
        collection_id=target,
        where_document=delete.where_document,
    )
|
|
|
@app.get(api_base + "/collections/{collection_id}/count")
def count(collection_id: str) -> int:
    """Return ``bkend._count`` for the collection identified by ``collection_id``."""
    print("Received count request")
    collection_uuid = _uuid(collection_id)
    return bkend._count(collection_uuid)
|
|
|
@app.post(api_base + "/collections/{collection_id}/query")
def get_nearest_neighbors(collection_id: str, query: QueryEmbedding) -> QueryResult:
    """Run a query against a collection via ``bkend._query``.

    Embeddings, filters, result count and included fields come from the QueryEmbedding body.
    """
    print("Received get_nearest_neighbors request")
    target = _uuid(collection_id)
    return bkend._query(
        collection_id=target,
        where=query.where,
        where_document=query.where_document,
        query_embeddings=query.query_embeddings,
        n_results=query.n_results,
        include=query.include,
    )
|
|
|
@app.post(api_base + "/collections/{collection_name}/create_index")
def create_index(collection_name: str) -> bool:
    """Ask the backend to build an index for the named collection."""
    print("Received create_index request")
    created = bkend.create_index(collection_name)
    return created
|
|
|
@app.put(api_base + "/collections/{collection_id}")
def modify(collection_id: str, collection: UpdateCollection) -> None:
    """Rename a collection and/or replace its metadata via ``bkend._modify``.

    NOTE(review): the ``-> None`` annotation does not match the value returned.
    """
    print("Received modify-collection request")
    target = _uuid(collection_id)
    return bkend._modify(
        id=target,
        new_name=collection.new_name,
        new_metadata=collection.new_metadata,
    )
|
|
|
@app.delete(api_base + "/collections/{collection_name}")
def delete_collection(collection_name: str) -> None:
    """Drop the named collection from the backend."""
    print("Received delete_collection request")
    outcome = bkend.delete_collection(collection_name)
    return outcome
|
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, serve the app on every interface, port 8000.
    listen_host, listen_port = "0.0.0.0", 8000
    uvicorn.run(app, host=listen_host, port=listen_port)
|
|
|
|