TencentVDB_graph_search / vector_db /vector_db_client.py
qcloud
1
5bfdfae
import tcvectordb
from tcvectordb.model.database import Database
from tcvectordb.model.collection import Collection
from tcvectordb.model.index import Index, VectorIndex, FilterIndex, HNSWParams
from tcvectordb.model.enum import FieldType, IndexType, MetricType
# Configuration keys that VectorDB.__init__ looks up through config.get().
# Passed to the client constructor as its `url`.
VDB_ADDRESS = "vector_db.address"
# Passed to the client constructor as its `key`.
VDB_KEY = "vector_db.key"
# Name of the database that VectorDB creates and opens.
AI_DB_NAME = "vector_db.ai_db"
# Name of the collection that VectorDB creates and returns.
AI_COLLECTION_NAME = "vector_db.ai_graph_emb_collection"
class VectorDB:
    """Wrapper around a ``tcvectordb`` RPC client for one database/collection.

    The connection settings and the database/collection names are read from a
    config object at construction time. The ``init_*`` methods (re)create the
    database and the graph-embedding collection; ``get_collection`` returns a
    handle to that collection.
    """

    def __init__(self, config):
        """Read settings from *config*, connect, and run a connectivity check.

        Args:
            config: Any object exposing ``get(key)``; it is queried with the
                module-level keys ``VDB_ADDRESS``, ``VDB_KEY``, ``AI_DB_NAME``
                and ``AI_COLLECTION_NAME``.

        Raises:
            Whatever the client constructor or ``list_databases()`` raises
            when the connection attempt fails.
        """
        self.address = config.get(VDB_ADDRESS)
        self.key = config.get(VDB_KEY)
        self.db_name = config.get(AI_DB_NAME)
        self.ai_graph_emb_collection = config.get(AI_COLLECTION_NAME)
        print(f"Try to connect vector db {self.address}")
        self.client = self.create_client()
        # Fail fast: issue one cheap request so a bad address/key surfaces here.
        self._test_simple()

    def create_client(self):
        """Build the RPC client from the address and key stored on ``self``."""
        return tcvectordb.RPCVectorDBClient(
            url=self.address,
            username='root',
            key=self.key,
            timeout=30,
        )

    def _test_simple(self):
        """Issue a ``list_databases()`` call as a connectivity smoke test."""
        self.client.list_databases()

    def init_database(self):
        """Create the configured database, replacing it if creation fails.

        WARNING: destructive. When ``create_database`` raises
        ``VectorDBException`` (presumably because the database already
        exists -- the exception is printed so this can be verified), the
        database is dropped and created again.
        """
        try:
            self.client.create_database(self.db_name)
        except tcvectordb.exceptions.VectorDBException as exc:
            # Make the destructive fallback visible instead of silent.
            print(f"Create database {self.db_name} failed ({exc}); "
                  f"dropping and recreating it")
            self.client.drop_database(self.db_name)
            self.client.create_database(self.db_name)

    def _create_graph_collection(self, database, index):
        """Create the graph-embedding collection in *database* using *index*.

        Single source of truth for the collection settings, shared by the
        first attempt and the drop-and-recreate retry.
        """
        database.create_collection(
            name=self.ai_graph_emb_collection,
            shard=1,
            replicas=2,
            index=index,
            description='this is a collection of graph embedding',
        )

    def init_graph_collection(self):
        """Create the graph-embedding collection, replacing it on failure.

        The collection is indexed by a string primary key ``id``, a string
        filter field ``local_graph_path`` and a 512-dimensional HNSW vector
        field ``vector`` using cosine distance.

        WARNING: destructive. When ``create_collection`` raises
        ``VectorDBException`` (presumably because the collection already
        exists -- the exception is printed so this can be verified), the
        collection is dropped and created again.
        """
        index = Index(
            FilterIndex(name='id', field_type=FieldType.String,
                        index_type=IndexType.PRIMARY_KEY),
            FilterIndex(name='local_graph_path', field_type=FieldType.String,
                        index_type=IndexType.FILTER),
            VectorIndex(name='vector', dimension=512, index_type=IndexType.HNSW,
                        metric_type=MetricType.COSINE,
                        params=HNSWParams(m=16, efconstruction=200)),
        )
        database: Database = self.client.database(self.db_name)
        try:
            self._create_graph_collection(database, index)
        except tcvectordb.exceptions.VectorDBException as exc:
            # Make the destructive fallback visible instead of silent.
            print(f"Create collection {self.ai_graph_emb_collection} failed "
                  f"({exc}); dropping and recreating it")
            database.drop_collection(self.ai_graph_emb_collection)
            self._create_graph_collection(database, index)

    def get_collection(self) -> Collection:
        """Return a handle to the configured graph-embedding collection."""
        database: Database = self.client.database(self.db_name)
        return database.collection(self.ai_graph_emb_collection)