anubhav77 commited on
Commit
46ca867
·
1 Parent(s): 5f34b1a

Add list/get/create collections

Browse files
Files changed (2) hide show
  1. main.py +36 -4
  2. server.py +45 -2
main.py CHANGED
@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
7
  from sse_starlette.sse import EventSourceResponse
8
  from starlette.responses import StreamingResponse
9
  from pydantic import BaseModel
10
- from typing import List, Dict, Any, Generator, Optional
11
  from server import client
12
  from chromadb.api.types import (
13
  Documents,
@@ -22,6 +22,18 @@ from chromadb.api.types import (
22
  QueryResult,
23
  CollectionMetadata,
24
  )
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  app = fastapi.FastAPI(title="ChromaDB")
@@ -40,11 +52,31 @@ def heartbeat():
40
  print("Received heartbeat request")
41
  return bkend.heartbeat()
42
 
43
- @app.get(api_base+"/collection")
44
- def list_collection():
45
- print("Received list_collection request")
46
  return bkend.list_collections()
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  if __name__ == "__main__":
49
  uvicorn.run(app, host="0.0.0.0", port=8000)
50
 
 
7
  from sse_starlette.sse import EventSourceResponse
8
  from starlette.responses import StreamingResponse
9
  from pydantic import BaseModel
10
+ from typing import List, Dict, Any, Generator, Optional, cast
11
  from server import client
12
  from chromadb.api.types import (
13
  Documents,
 
22
  QueryResult,
23
  CollectionMetadata,
24
  )
25
+ from chromadb.api import API
26
+ from chromadb.config import System
27
+ import chromadb.utils.embedding_functions as ef
28
+ import pandas as pd
29
+ import requests
30
+ import json
31
+ from typing import Sequence
32
+ from chromadb.api.models.Collection import Collection
33
+ import chromadb.errors as errors
34
+ from uuid import UUID
35
+ from chromadb.telemetry import Telemetry
36
+ from overrides import override
37
 
38
 
39
  app = fastapi.FastAPI(title="ChromaDB")
 
52
  print("Received heartbeat request")
53
  return bkend.heartbeat()
54
 
55
+ @app.get(api_base+"/collections")
56
+ def list_collections():
57
+ print("Received list_collections request")
58
  return bkend.list_collections()
59
 
60
+ @app.post(api_base+"/collections")
61
+ def create_collection(
62
+ self,
63
+ name: str,
64
+ metadata: Optional[CollectionMetadata] = None,
65
+ embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
66
+ get_or_create: bool = False,
67
+ ) -> Collection:
68
+ print("Received request to create_collection")
69
+ return bkend.create_collection(name,metadata=metadata,embedding_function=embedding_function,get_or_create=get_or_create)
70
+
71
+ @app.get(api_base+"/collections")
72
+ def get_collection(
73
+ self,
74
+ name: str,
75
+ embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
76
+ ) -> Collection:
77
+ return bkend.get_collection(name,embedding_function=embedding_function
78
+ )
79
+
80
  if __name__ == "__main__":
81
  uvicorn.run(app, host="0.0.0.0", port=8000)
82
 
server.py CHANGED
@@ -1,5 +1,31 @@
1
  from chromadb.config import Settings
2
  import chromadb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import time
4
 
5
  class client():
@@ -10,7 +36,24 @@ class client():
10
  ))
11
 
12
  def heartbeat(self):
13
- return int(time.time_ns())
14
 
15
  def list_collections(self):
16
- return self.db.list_collections()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from chromadb.config import Settings
2
  import chromadb
3
+ from typing import Optional, cast
4
+ from chromadb.api import API
5
+ from chromadb.config import System
6
+ from chromadb.api.types import (
7
+ Documents,
8
+ Embeddings,
9
+ EmbeddingFunction,
10
+ IDs,
11
+ Include,
12
+ Metadatas,
13
+ Where,
14
+ WhereDocument,
15
+ GetResult,
16
+ QueryResult,
17
+ CollectionMetadata,
18
+ )
19
+ import chromadb.utils.embedding_functions as ef
20
+ import pandas as pd
21
+ import requests
22
+ import json
23
+ from typing import Sequence
24
+ from chromadb.api.models.Collection import Collection
25
+ import chromadb.errors as errors
26
+ from uuid import UUID
27
+ from chromadb.telemetry import Telemetry
28
+ from overrides import override
29
  import time
30
 
31
  class client():
 
36
  ))
37
 
38
  def heartbeat(self):
39
+ return {"nanosecond heartbeat":int(time.time_ns())}
40
 
41
  def list_collections(self):
42
+ return self.db.list_collections()
43
+
44
+ def create_collection(
45
+ self,
46
+ name: str,
47
+ metadata: Optional[CollectionMetadata] = None,
48
+ embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
49
+ get_or_create: bool = False,
50
+ ) -> Collection:
51
+ return self.db.create_collection(name,metadata=metadata,embedding_function=embedding_function,get_or_create=get_or_create)
52
+
53
+ def get_collection(
54
+ self,
55
+ name: str,
56
+ embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
57
+ ) -> Collection:
58
+ return self.db.get_collection(name,embedding_function=embedding_function)
59
+