Spaces:
Runtime error
Runtime error
import tcvectordb | |
from tcvectordb.model.database import Database | |
from tcvectordb.model.collection import Collection | |
from tcvectordb.model.index import Index, VectorIndex, FilterIndex, HNSWParams | |
from tcvectordb.model.enum import FieldType, IndexType, MetricType | |
# Keys looked up (via ``config.get``) in the config object passed to VectorDB.
VDB_ADDRESS: str = "vector_db.address"  # value is used as the client url
VDB_KEY: str = "vector_db.key"  # value is used as the client key
AI_DB_NAME: str = "vector_db.ai_db"  # value is the database name
AI_COLLECTION_NAME: str = "vector_db.ai_graph_emb_collection"  # value is the collection name
class VectorDB:
    """Wrapper around a ``tcvectordb.RPCVectorDBClient`` for the graph-embedding store.

    Settings are read from ``config`` (any object with a ``get`` method) under
    the module-level keys ``VDB_ADDRESS``, ``VDB_KEY``, ``AI_DB_NAME`` and
    ``AI_COLLECTION_NAME``. A client is created and probed on construction.
    """

    def __init__(self, config):
        """Read settings from ``config``, create the client and probe it.

        Missing keys yield ``None`` because ``config.get`` is used; such a
        config is not rejected here.

        Any exception raised while creating the client or listing databases
        propagates to the caller.
        """
        self.address = config.get(VDB_ADDRESS)
        self.key = config.get(VDB_KEY)
        self.db_name = config.get(AI_DB_NAME)
        self.ai_graph_emb_collection = config.get(AI_COLLECTION_NAME)
        print(f"Try to connect vector db {self.address}")
        self.client = self.create_client()
        # Probe right away so a bad address/key fails at construction time
        # rather than on first real use.
        self._test_simple()

    def create_client(self, username='root', timeout=30):
        """Create an RPC client for ``self.address`` authenticated with ``self.key``.

        Args:
            username: Account name (default ``'root'``, the previous hard-coded value).
            timeout: Passed straight to ``RPCVectorDBClient`` (default 30);
                presumably seconds -- confirm against the SDK.

        Returns:
            A new ``tcvectordb.RPCVectorDBClient``.
        """
        return tcvectordb.RPCVectorDBClient(
            url=self.address,
            username=username,
            key=self.key,
            timeout=timeout,
        )

    def _test_simple(self):
        """Issue a cheap request (list databases) to check the connection."""
        self.client.list_databases()

    @staticmethod
    def _create_or_replace(create, drop):
        """Call ``create()``; on ``VectorDBException`` call ``drop()`` then ``create()`` again.

        The retry is not guarded. If ``drop()`` or the second ``create()`` raises,
        that exception propagates, with the first failure attached as its
        implicit ``__context__``.
        """
        try:
            create()
        except tcvectordb.exceptions.VectorDBException:
            drop()
            create()

    def init_database(self):
        """Create the database ``self.db_name``, recreating it if creation fails.

        WARNING: destructive. If ``create_database`` raises a
        ``VectorDBException`` (presumably because the database already
        exists), the database is dropped and created again.
        """
        self._create_or_replace(
            lambda: self.client.create_database(self.db_name),
            lambda: self.client.drop_database(self.db_name),
        )

    @staticmethod
    def _build_graph_index(dimension=512):
        """Return the ``Index`` schema for the graph-embedding collection.

        Fields:
          - ``id``: string primary key.
          - ``local_graph_path``: filterable string.
          - ``vector``: HNSW (m=16, efconstruction=200) cosine vector index
            of size ``dimension``.
        """
        return Index(
            FilterIndex(name='id', field_type=FieldType.String,
                        index_type=IndexType.PRIMARY_KEY),
            FilterIndex(name='local_graph_path', field_type=FieldType.String,
                        index_type=IndexType.FILTER),
            VectorIndex(name='vector', dimension=dimension,
                        index_type=IndexType.HNSW,
                        metric_type=MetricType.COSINE,
                        params=HNSWParams(m=16, efconstruction=200)),
        )

    def init_graph_collection(self, dimension=512, shard=1, replicas=2):
        """Create the graph-embedding collection in ``self.db_name``, recreating it if needed.

        WARNING: destructive. If ``create_collection`` raises a
        ``VectorDBException``, the collection is dropped and created again.

        Args:
            dimension: Size of the ``vector`` field (default 512, the previous
                hard-coded value).
            shard: Shard count passed to ``create_collection`` (default 1).
            replicas: Replica count passed to ``create_collection`` (default 2).
        """
        index = self._build_graph_index(dimension)
        database: Database = self.client.database(self.db_name)

        def create():
            # Single definition of the creation arguments, shared by the
            # first attempt and the post-drop retry.
            database.create_collection(
                name=self.ai_graph_emb_collection,
                shard=shard,
                replicas=replicas,
                index=index,
                description='this is a collection of graph embedding',
            )

        self._create_or_replace(
            create,
            lambda: database.drop_collection(self.ai_graph_emb_collection),
        )

    def get_collection(self) -> Collection:
        """Return the graph-embedding collection handle from database ``self.db_name``."""
        database: Database = self.client.database(self.db_name)
        return database.collection(self.ai_graph_emb_collection)