import hashlib
from datetime import datetime
from typing import List, Optional

import pandas as pd
from clickhouse_connect import get_client
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
from streamlit.runtime.uploaded_file_manager import UploadedFile

from backend.chat_bot.tools import parse_files, extract_embedding
from backend.construct.build_retriever_tool import create_retriever_tool
from logger import logger


class ChatBotKnowledgeTable:
    def __init__(self, host, port, username, password,
                 embedding: Embeddings, parser_api_key: str, db="chat",
                 kb_table="private_kb", tool_table="private_tool") -> None:
        super().__init__()
        # Paragraph-level storage for uploaded files: each row holds one text
        # chunk and its 768-dim embedding, indexed for cosine similarity search.
        personal_files_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
                entity_id String,
                file_name String,
                text String,
                user_id String,
                created_by DateTime,
                vector Array(Float32),
                CONSTRAINT cons_vec_len CHECK length(vector) = 768,
                VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
            ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """
        # `tool_name` represents the private knowledge base name.
        private_knowledge_base_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
                tool_id String,
                tool_name String,
                file_names Array(String),
                user_id String,
                created_by DateTime,
                tool_description String
            ) ENGINE = ReplacingMergeTree ORDER BY tool_id
        """
        self.personal_files_table = kb_table
        self.private_knowledge_base_table = tool_table
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        self.client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        self.client.command("SET allow_experimental_object_type=1")
        self.client.command(personal_files_schema_)
        self.client.command(private_knowledge_base_schema_)
        self.parser_api_key = parser_api_key
        self.vector_store = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            must_have_cols=["file_name", "text", "created_by"],
        )

    # List all files with given `user_id`
    def list_files(self, user_id: str):
        query = f"""
        SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
               arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
        FROM {self.vector_store.config.database}.{self.personal_files_table}
        WHERE user_id = '{user_id}' GROUP BY file_name
        """
        return [r for r in self.vector_store.client.query(query).named_results()]

    # Parse and embed files
    def add_by_file(self, user_id, files: List[UploadedFile]):
        data = parse_files(self.parser_api_key, user_id, files)
        data = extract_embedding(self.vector_store.embeddings, data)
        self.vector_store.client.insert_df(
            table=self.personal_files_table,
            df=pd.DataFrame(data),
            database=self.vector_store.config.database,
        )

    # Remove all files and private knowledge bases with given `user_id`
    def clear(self, user_id: str):
        self.vector_store.client.command(
            f"DELETE FROM {self.vector_store.config.database}.{self.personal_files_table} "
            f"WHERE user_id='{user_id}'"
        )
        query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id = '{user_id}'"""
        self.vector_store.client.command(query)

    def create_private_knowledge_base(
        self,
        user_id: str,
        tool_name: str,
        tool_description: str,
        files: Optional[List[str]] = None,
    ):
        self.vector_store.client.insert_df(
            self.private_knowledge_base_table,
            pd.DataFrame(
                [
                    {
                        # Deterministic ID: re-creating the same (user_id, tool_name)
                        # pair overwrites the old row via ReplacingMergeTree.
                        "tool_id": hashlib.sha256(
                            (user_id + tool_name).encode("utf-8")
                        ).hexdigest(),
                        # `tool_name` represents the user's private knowledge base.
                        "tool_name": tool_name,
                        "file_names": files,
                        "user_id": user_id,
                        "created_by": datetime.now(),
                        "tool_description": tool_description,
                    }
                ]
            ),
            database=self.vector_store.config.database,
        )

    # Show all private knowledge bases with given `user_id`
    def list_private_knowledge_bases(self, user_id: str, private_knowledge_base=None):
        extended_where = (
            f"AND tool_name = '{private_knowledge_base}'" if private_knowledge_base else ""
        )
        query = f"""
        SELECT tool_name, tool_description, length(file_names)
        FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
        WHERE user_id = '{user_id}' {extended_where}
        """
        return [r for r in self.vector_store.client.query(query).named_results()]

    def remove_private_knowledge_bases(self, user_id: str, private_knowledge_bases: List[str]):
        unique_list = list(set(private_knowledge_bases))
        unique_list = ",".join([f"'{t}'" for t in unique_list])
        query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id = '{user_id}' AND tool_name IN [{unique_list}]"""
        self.vector_store.client.command(query)

    # Build one retriever tool per private knowledge base owned by `user_id`,
    # each restricted to that user's files via a `where_str` metadata filter.
    def as_retrieval_tools(self, user_id, tool_name=None):
        private_knowledge_bases = self.list_private_knowledge_bases(
            user_id=user_id, private_knowledge_base=tool_name
        )
        retrievers = {}
        for private_kb in private_knowledge_bases:
            file_names_sql = f"""
            SELECT arrayJoin(file_names) FROM (
                SELECT file_names
                FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                WHERE user_id = '{user_id}' AND tool_name = '{private_kb["tool_name"]}'
            )
            """
            logger.info(f"user_id is {user_id}, file_names_sql is {file_names_sql}")
            res = self.client.query(file_names_sql)
            file_names = [line[0] for line in res.result_rows]
            file_names = ", ".join(f"'{item}'" for item in file_names)
            logger.info(f"user_id is {user_id}, file_names is {file_names}")
            retrievers[private_kb["tool_name"]] = create_retriever_tool(
                self.vector_store.as_retriever(
                    search_kwargs={
                        "where_str": f"user_id='{user_id}' AND file_name IN ({file_names})"
                    },
                ),
                tool_name=private_kb["tool_name"],
                description=private_kb["tool_description"],
            )
        return retrievers