Spaces:
Running
Running
import pandas as pd | |
import hashlib | |
import requests | |
from typing import List, Optional | |
from datetime import datetime | |
from langchain.schema.embeddings import Embeddings | |
from streamlit.runtime.uploaded_file_manager import UploadedFile | |
from clickhouse_connect import get_client | |
from multiprocessing.pool import ThreadPool | |
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings | |
from .helper import create_retriever_tool | |
# Endpoint that parse_files() posts uploaded documents to for text extraction.
parser_url = "https://api.unstructured.io/general/v0/general"
def parse_files(api_key, user_id, files: List[UploadedFile]):
    """Send uploaded files to the parser service and collect paragraph rows.

    Each file is posted to ``parser_url`` from a pool of 8 worker threads.
    Only elements whose ``type`` is ``"NarrativeText"`` and whose text has
    more than 10 space-separated tokens are kept.

    Args:
        api_key: Value sent in the ``unstructured-api-key`` request header.
        user_id: Owner identifier stored on every returned row.
        files: Uploaded files to parse.

    Returns:
        A flat list of dicts with the keys ``text``, ``file_name``,
        ``entity_id``, ``user_id`` and ``created_by``. Rows arrive in
        completion order (``imap_unordered``), not in input order.

    Raises:
        ValueError: If the parser answers with a non-200 status code. The
            message is the decoded JSON body when there is one, otherwise
            the raw response text.
    """

    def parse_file(file: UploadedFile):
        """Parse one upload and return its filtered paragraph rows."""
        headers = {
            "accept": "application/json",
            "unstructured-api-key": api_key,
        }
        data = {"strategy": "auto", "ocr_languages": ["eng"]}
        # Take the bytes once via getvalue() and use them for both the hash
        # and the upload.  Hashing file.read() depends on the current stream
        # position, so an already-consumed upload would be hashed as empty
        # bytes and produce colliding entity ids across different files.
        content = file.getvalue()
        file_hash = hashlib.sha256(content).hexdigest()
        file_data = {"files": (file.name, content, file.type)}
        response = requests.post(
            parser_url, headers=headers, data=data, files=file_data
        )
        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON; fall back to the
            # raw text so the real failure is reported instead of a JSON
            # decoding error.
            try:
                detail = str(response.json())
            except ValueError:
                detail = response.text
            raise ValueError(detail)
        json_response = response.json()
        texts = [
            {
                "text": t["text"],
                "file_name": t["metadata"]["filename"],
                # Stable id: same file content + same paragraph -> same id.
                "entity_id": hashlib.sha256(
                    (file_hash + t["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "created_by": datetime.now(),
            }
            for t in json_response
            if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
        ]
        return texts

    with ThreadPool(8) as p:
        rows = []
        for r in p.imap_unordered(parse_file, files):
            rows.extend(r)
        return rows
def extract_embedding(embeddings: Embeddings, texts):
    """Attach an embedding vector to every row of *texts*.

    The ``"text"`` value of each row is embedded in a single
    ``embed_documents`` call and the result is stored under the row's
    ``"vector"`` key. Rows are modified in place.

    Args:
        embeddings: Object exposing ``embed_documents(list_of_str)``.
        texts: Sequence of dict rows, each with a ``"text"`` entry.

    Returns:
        The same ``texts`` object that was passed in.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    if len(texts) == 0:
        raise ValueError("No texts extracted!")

    vectors = embeddings.embed_documents([row["text"] for row in texts])
    for position, row in enumerate(texts):
        row["vector"] = vectors[position]
    return texts
class PrivateKnowledgeBase:
    """Per-user document store plus a registry of retrieval tools.

    Two tables are used: ``kb_table`` holds parsed paragraphs with their
    embedding vectors, and ``tool_table`` holds named tools, each of which
    is a list of file names belonging to a user.  All queries are scoped by
    ``user_id``.

    Values that end up inside SQL text (user ids, tool names) are always
    rendered through :meth:`_sql_str` so that quotes in them cannot change
    the statement.
    """

    def __init__(
        self,
        host,
        port,
        username,
        password,
        embedding: Embeddings,
        parser_api_key,
        db="chat",
        kb_table="private_kb",
        tool_table="private_tool",
    ) -> None:
        """Create both tables if missing and set up the vector store.

        Args:
            host: Database host.
            port: Database port.
            username: Database user.
            password: Database password.
            embedding: Embedding model handed to the vector store.
            parser_api_key: API key later passed to ``parse_files``.
            db: Database that holds both tables.
            kb_table: Name of the paragraph table.
            tool_table: Name of the tool table.
        """
        super().__init__()
        # The schema constrains every stored vector to length 768.
        kb_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
                entity_id String,
                file_name String,
                text String,
                user_id String,
                created_by DateTime,
                vector Array(Float32),
                CONSTRAINT cons_vec_len CHECK length(vector) = 768,
                VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
            ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """
        tool_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
                tool_id String,
                tool_name String,
                file_names Array(String),
                user_id String,
                created_by DateTime,
                tool_description String
            ) ENGINE = ReplacingMergeTree ORDER BY tool_id
        """
        self.kb_table = kb_table
        self.tool_table = tool_table
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        # A separate client is used for the DDL because the vector store
        # below is only constructed after the tables exist.
        client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        client.command("SET allow_experimental_object_type=1")
        client.command(kb_schema_)
        client.command(tool_schema_)
        self.parser_api_key = parser_api_key
        self.vstore = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            must_have_cols=["file_name", "text", "created_by"],
        )

    @staticmethod
    def _sql_str(value) -> str:
        """Return *value* as a quoted string literal safe to embed in SQL.

        Backslashes and single quotes are backslash-escaped, then the result
        is wrapped in single quotes.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def list_files(self, user_id, tool_name=None):
        """List the files stored for *user_id* with per-file statistics.

        Args:
            user_id: Owner whose files are listed.
            tool_name: Unused; kept so existing callers keep working.

        Returns:
            A list of row dicts with ``file_name``, ``num_paragraph`` and
            ``max_chars``.
        """
        query = f"""
            SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
                arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
            FROM {self.vstore.config.database}.{self.kb_table}
            WHERE user_id = {self._sql_str(user_id)} GROUP BY file_name
        """
        return list(self.vstore.client.query(query).named_results())

    def add_by_file(
        self, user_id, files: List[UploadedFile], **kwargs
    ):
        """Parse *files*, embed the paragraphs and insert them for *user_id*.

        Args:
            user_id: Owner recorded on every inserted row.
            files: Uploaded files to ingest.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: Propagated from parsing or from ``extract_embedding``
                when no paragraph was extracted.
        """
        data = parse_files(self.parser_api_key, user_id, files)
        data = extract_embedding(self.vstore.embeddings, data)
        self.vstore.client.insert_df(
            self.kb_table,
            pd.DataFrame(data),
            database=self.vstore.config.database,
        )

    def clear(self, user_id):
        """Delete every paragraph and every tool owned by *user_id*."""
        database = self.vstore.config.database
        user_literal = self._sql_str(user_id)
        self.vstore.client.command(
            f"DELETE FROM {database}.{self.kb_table} "
            f"WHERE user_id={user_literal}"
        )
        query = f"""DELETE FROM {database}.{self.tool_table}
            WHERE user_id = {user_literal}"""
        self.vstore.client.command(query)

    def create_tool(
        self, user_id, tool_name, tool_description, files: Optional[List[str]] = None
    ):
        """Register a tool covering *files* for *user_id*.

        ``tool_id`` is the SHA-256 of ``user_id + tool_name``, so the same
        user/name pair always maps to the same id.

        Args:
            user_id: Owner of the tool.
            tool_name: Name of the tool.
            tool_description: Description stored alongside the tool.
            files: File names the tool should search; ``None`` is stored as
                an empty list because the column is a non-nullable array.
        """
        file_names = list(files) if files else []
        self.vstore.client.insert_df(
            self.tool_table,
            pd.DataFrame(
                [
                    {
                        "tool_id": hashlib.sha256(
                            (user_id + tool_name).encode("utf-8")
                        ).hexdigest(),
                        "tool_name": tool_name,
                        "file_names": file_names,
                        "user_id": user_id,
                        "created_by": datetime.now(),
                        "tool_description": tool_description,
                    }
                ]
            ),
            database=self.vstore.config.database,
        )

    def list_tools(self, user_id, tool_name=None):
        """List the tools of *user_id*, optionally only the one named *tool_name*.

        Returns:
            A list of row dicts with ``tool_name``, ``tool_description`` and
            the number of files of each tool.
        """
        extended_where = (
            f"AND tool_name = {self._sql_str(tool_name)}" if tool_name else ""
        )
        query = f"""
            SELECT tool_name, tool_description, length(file_names)
            FROM {self.vstore.config.database}.{self.tool_table}
            WHERE user_id = {self._sql_str(user_id)} {extended_where}
        """
        return list(self.vstore.client.query(query).named_results())

    def remove_tools(self, user_id, tool_names):
        """Delete the tools of *user_id* whose names are in *tool_names*.

        An empty *tool_names* is a no-op; no statement is sent.
        """
        quoted_names = [self._sql_str(name) for name in tool_names]
        if not quoted_names:
            return
        name_list = ",".join(quoted_names)
        query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
            WHERE user_id = {self._sql_str(user_id)} AND tool_name IN [{name_list}]"""
        self.vstore.client.command(query)

    def as_tools(self, user_id, tool_name=None):
        """Build one retriever tool per registered tool of *user_id*.

        Each retriever is restricted, through ``where_str``, to rows of
        *user_id* whose ``file_name`` is listed in that tool's
        ``file_names``.

        Returns:
            A dict mapping tool name to the object returned by
            ``create_retriever_tool``.
        """
        user_literal = self._sql_str(user_id)
        tool_table = f"{self.vstore.config.database}.{self.tool_table}"
        retrievers = {}
        for tool in self.list_tools(user_id=user_id, tool_name=tool_name):
            # Tool names come back from the database but were originally
            # supplied by users, so they are escaped as well.
            tool_literal = self._sql_str(tool["tool_name"])
            where_str = (
                f"user_id={user_literal} "
                f"""AND file_name IN (
                    SELECT arrayJoin(file_names) FROM (
                        SELECT file_names
                        FROM {tool_table}
                        WHERE user_id = {user_literal} AND tool_name = {tool_literal})
                )"""
            )
            retrievers[tool["tool_name"]] = create_retriever_tool(
                self.vstore.as_retriever(
                    search_kwargs={"where_str": where_str},
                ),
                name=tool["tool_name"],
                description=tool["tool_description"],
            )
        return retrievers