# ChatData/lib/private_kb.py
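"""Per-user private knowledge base backed by MyScale.

Uploaded files are parsed into paragraphs by the hosted Unstructured API,
embedded, and stored in a vector table. "Tools" group files into named
retrievers that an agent can call. Note: user_id and tool names are
interpolated directly into SQL, so callers must pass trusted values.
"""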
import hashlib
from datetime import datetime
from multiprocessing.pool import ThreadPool
from typing import List, Optional

import pandas as pd
import requests
from clickhouse_connect import get_client
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.myscale import MyScaleSettings, MyScaleWithoutJSON
from streamlit.runtime.uploaded_file_manager import UploadedFile

from .helper import create_retriever_tool

parser_url = "https://api.unstructured.io/general/v0/general"


def parse_files(api_key, user_id, files: List[UploadedFile]):
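    """Parse uploaded files into text paragraphs via the Unstructured API.

    Returns one dict per kept paragraph, carrying the text, its provenance,
    and a content-derived entity_id.
    """
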
def parse_file(file: UploadedFile):
headers = {
"accept": "application/json",
"unstructured-api-key": api_key,
}
data = {"strategy": "auto", "ocr_languages": ["eng"]}
        # getvalue() returns the full buffer without consuming the stream,
        # so the hash and the upload below both see the whole file.
        file_hash = hashlib.sha256(file.getvalue()).hexdigest()
        file_data = {"files": (file.name, file.getvalue(), file.type)}
        response = requests.post(
            parser_url, headers=headers, data=data, files=file_data,
            timeout=300,  # keep a stalled parse from hanging the pool forever
        )
        if response.status_code != 200:
            # The error body may not be JSON, so surface the raw text.
            raise ValueError(response.text)
        json_response = response.json()
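        # Keep only narrative paragraphs longer than ten words. entity_id
        # hashes file content plus paragraph text, so re-uploading the same
        # file deduplicates through the ReplacingMergeTree engine.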
texts = [
{
"text": t["text"],
"file_name": t["metadata"]["filename"],
"entity_id": hashlib.sha256(
(file_hash + t["text"]).encode()
).hexdigest(),
"user_id": user_id,
"created_by": datetime.now(),
}
for t in json_response
if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
]
return texts
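
    # The work is I/O bound (one HTTP call per file), so a small thread
    # pool parallelizes uploads without multiprocessing overhead.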
with ThreadPool(8) as p:
rows = []
for r in p.imap_unordered(parse_file, files):
rows.extend(r)
return rows
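

# Embeddings are requested in one batched call per upload rather than one
# call per paragraph, which keeps round-trips to the embedding backend low.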
def extract_embedding(embeddings: Embeddings, texts):
    if not texts:
        raise ValueError("No texts extracted!")
    embs = embeddings.embed_documents([t["text"] for t in texts])
    for row, emb in zip(texts, embs):
        row["vector"] = emb
    return texts


class PrivateKnowledgeBase:
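    """MyScale-backed store of per-user documents and retriever tools."""
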
def __init__(
self,
host,
port,
username,
password,
embedding: Embeddings,
parser_api_key,
db="chat",
kb_table="private_kb",
tool_table="private_tool",
) -> None:
        # Two tables: one row per stored paragraph (with its vector), and
        # one row per tool definition that groups a set of files.
kb_schema_ = f"""
CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
entity_id String,
file_name String,
text String,
user_id String,
created_by DateTime,
vector Array(Float32),
CONSTRAINT cons_vec_len CHECK length(vector) = 768,
VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
) ENGINE = ReplacingMergeTree ORDER BY entity_id
"""
tool_schema_ = f"""
CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
tool_id String,
tool_name String,
file_names Array(String),
user_id String,
created_by DateTime,
tool_description String
) ENGINE = ReplacingMergeTree ORDER BY tool_id
"""
self.kb_table = kb_table
self.tool_table = tool_table
config = MyScaleSettings(
host=host,
port=port,
username=username,
password=password,
database=db,
table=kb_table,
)
client = get_client(
host=config.host,
port=config.port,
username=config.username,
password=config.password,
)
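        # Create the tables with a raw client first; the vector store
        # constructed below assumes they already exist.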
client.command("SET allow_experimental_object_type=1")
client.command(kb_schema_)
client.command(tool_schema_)
self.parser_api_key = parser_api_key
self.vstore = MyScaleWithoutJSON(
embedding=embedding,
config=config,
must_have_cols=["file_name", "text", "created_by"],
)

    def list_files(self, user_id, tool_name=None):
        """List a user's files with paragraph counts and longest-paragraph length.

        tool_name is accepted for interface symmetry but is not used yet.
        """
        # GROUP BY already returns one row per file, so DISTINCT is redundant.
        query = f"""
        SELECT file_name, COUNT(entity_id) AS num_paragraph,
            arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
        FROM {self.vstore.config.database}.{self.kb_table}
        WHERE user_id = '{user_id}' GROUP BY file_name
        """
        return list(self.vstore.client.query(query).named_results())
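
    # Ingestion pipeline: parse with Unstructured, embed in one batch, then
    # bulk-insert the rows as a DataFrame.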
    def add_by_file(self, user_id, files: List[UploadedFile], **kwargs):
data = parse_files(self.parser_api_key, user_id, files)
data = extract_embedding(self.vstore.embeddings, data)
self.vstore.client.insert_df(
self.kb_table,
pd.DataFrame(data),
database=self.vstore.config.database,
)

    def clear(self, user_id):
        """Delete all of a user's paragraphs and tool definitions."""
        for table in (self.kb_table, self.tool_table):
            self.vstore.client.command(
                f"DELETE FROM {self.vstore.config.database}.{table} "
                f"WHERE user_id='{user_id}'"
            )
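
    # A "tool" is a named retriever over a subset of the user's files; its
    # description is what the agent reads when deciding whether to call it.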
def create_tool(
self, user_id, tool_name, tool_description, files: Optional[List[str]] = None
):
self.vstore.client.insert_df(
self.tool_table,
pd.DataFrame(
[
{
"tool_id": hashlib.sha256(
(user_id + tool_name).encode("utf-8")
).hexdigest(),
"tool_name": tool_name,
"file_names": files,
"user_id": user_id,
"created_by": datetime.now(),
"tool_description": tool_description,
}
]
),
database=self.vstore.config.database,
)

    def list_tools(self, user_id, tool_name=None):
        """List a user's tools, optionally filtered to a single tool name."""
        extended_where = f"AND tool_name = '{tool_name}'" if tool_name else ""
        query = f"""
        SELECT tool_name, tool_description, length(file_names) AS num_files
        FROM {self.vstore.config.database}.{self.tool_table}
        WHERE user_id = '{user_id}' {extended_where}
        """
        return list(self.vstore.client.query(query).named_results())

    def remove_tools(self, user_id, tool_names):
        """Delete the named tools; the underlying paragraphs are kept."""
        quoted = ", ".join(f"'{t}'" for t in tool_names)
        query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
            WHERE user_id = '{user_id}' AND tool_name IN [{quoted}]"""
        self.vstore.client.command(query)
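
    # Build one LangChain retriever per tool. The where_str scopes vector
    # search to the user's rows and to the files registered for that tool.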
def as_tools(self, user_id, tool_name=None):
tools = self.list_tools(user_id=user_id, tool_name=tool_name)
retrievers = {
t["tool_name"]: create_retriever_tool(
self.vstore.as_retriever(
search_kwargs={
"where_str": (
f"user_id='{user_id}' "
f"""AND file_name IN (
SELECT arrayJoin(file_names) FROM (
SELECT file_names
FROM {self.vstore.config.database}.{self.tool_table}
WHERE user_id = '{user_id}' AND tool_name = '{t['tool_name']}')
)"""
)
},
),
name=t["tool_name"],
description=t["tool_description"],
)
for t in tools
}
return retrievers
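

# Minimal usage sketch (the host, credentials, and embedding model below are
# placeholders, not values shipped with this module; any 768-dimensional
# Embeddings implementation satisfies the table's length constraint):
#
#   from langchain.embeddings import HuggingFaceEmbeddings
#
#   kb = PrivateKnowledgeBase(
#       host="msc-example-host", port=443, username="user", password="...",
#       embedding=HuggingFaceEmbeddings(model_name="BAAI/bge-base-en"),
#       parser_api_key="UNSTRUCTURED_API_KEY",
#   )
#   kb.add_by_file(user_id="u1", files=uploaded_files)  # List[UploadedFile]
#   kb.create_tool("u1", "my_docs", "Searches my uploaded documents",
#                  files=[f["file_name"] for f in kb.list_files("u1")])
#   tools = kb.as_tools("u1")  # {"my_docs": Tool(...)}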