marcellopoliti committed
Commit e04dd70
1 Parent(s): dd3a3a4

Add application file

.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
README.md CHANGED
@@ -1,12 +1,6 @@
- ---
- title: Brian Knows Collections
- emoji: 🔥
- colorFrom: red
- colorTo: yellow
- sdk: streamlit
- sdk_version: 1.32.2
- app_file: app.py
- pinned: false
- ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # brian-knows-streamlit

+ UI for knowledge-base (KB) editing
+
+ TODO:
+ store splits with URLs: check the TypeScript repo
app.py ADDED
@@ -0,0 +1,47 @@
+ # Swap the stdlib sqlite3 for pysqlite3 before chromadb is imported
+ # (chromadb needs a newer sqlite3 than some hosts provide).
+ __import__("pysqlite3")
+ import sys
+
+ sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+
+ import hmac
+
+ import streamlit as st
+
+ # utils imports chromadb, so this must come after the sqlite3 swap above.
+ from utils import get_chroma_client, get_embedding_function
+
+ st.set_page_config(page_title="Hello", page_icon="👋", layout="wide")
+
+
+ def check_password():
+     """Returns `True` if the user entered the correct password."""
+
+     def password_entered():
+         """Checks whether the password entered by the user is correct."""
+         if hmac.compare_digest(st.session_state["password"], st.secrets["password"]):
+             st.session_state["password_correct"] = True
+             del st.session_state["password"]  # Don't store the password.
+         else:
+             st.session_state["password_correct"] = False
+
+     # Return True if the password is validated.
+     if st.session_state.get("password_correct", False):
+         return True
+
+     # Show input for password.
+     st.text_input(
+         "Password", type="password", on_change=password_entered, key="password"
+     )
+     if "password_correct" in st.session_state:
+         st.error("😕 Password incorrect")
+     return False
+
+
+ if not check_password():
+     st.stop()  # Do not continue if check_password is not True.
+
+ # Main Streamlit app starts here
+ st.write("# Brian Knowledge Base System! 👋")
+ client = get_chroma_client()
+ default_embedding_function = get_embedding_function()
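A hedged aside, not part of the commit: `st.secrets["password"]` above is read from Streamlit's secrets store, so the gate only works once a `password` secret exists (locally via `.streamlit/secrets.toml`; on a Space, via the Space's secrets settings). A minimal sketch of a friendlier failure mode:

```python
# Hedged sketch: fail fast if the secret is missing instead of raising KeyError.
# Locally, .streamlit/secrets.toml would contain:  password = "change-me"  (hypothetical value)
import streamlit as st

if "password" not in st.secrets:
    st.error("Configure a `password` secret before using this app.")
    st.stop()
```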
conda.yml ADDED
@@ -0,0 +1,22 @@
+ name: brian_knows
+ channels:
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.11.4
+   - pip
+   - pip:
+       - langchain>=0.1.6
+       - openai==1.14.2
+       - beautifulsoup4==4.12.2
+       - tiktoken==0.5.1
+       - chromadb>=0.4.22
+       - pandas==2.1.1
+       - streamlit==1.27.2
+       - python-dotenv==1.0.0
+       - fastapi==0.104.0
+       - uvicorn==0.23.2
+       - pypdf==3.16.4
+       - python-multipart==0.0.6
+       - matplotlib==3.8.3
+       - umap-learn==0.5.5
generate_kb.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import re
+ import secrets
+
+ import pandas as pd
+ from dotenv import load_dotenv
+
+ from main import client, default_embedding_function
+ from services.document_manager.document_loader import DocumentsLoader
+
+ # from services.vectordb_manager.vectordb_manager import VectordbManager
+
+ load_dotenv()
+ openai_key = os.getenv("OPENAI_API_KEY")
+
+
+ def generate_knowledge_box_from_url(
+     client,
+     kb_name: str,
+     urls: list,
+     embedding_fct=default_embedding_function,
+     chunk_size: int = 2_000,
+ ):
+     dl = DocumentsLoader()
+     docs = dl.load_docs(urls)
+     splits = dl.split_docs(docs, chunk_size=chunk_size)
+     contents = [split.page_content for split in splits]
+     metadatas = [split.metadata for split in splits]
+     # Collapse runs of newlines to clean the text a bit.
+     cleaned_contents = [re.sub(r"\n+", " ", content) for content in contents]
+     chroma_collection = client.create_collection(
+         kb_name,
+         embedding_function=embedding_fct,
+         metadata={"hnsw:space": "cosine"},
+     )
+     ids = [secrets.token_hex(16) for _ in cleaned_contents]
+     chroma_collection.add(documents=cleaned_contents, ids=ids, metadatas=metadatas)
+     n_splits = chroma_collection.count()
+     return {"status": 200, "n_split": n_splits}
+
+
+ def add_links_to_knowledge_base(
+     client,
+     kb_name: str,
+     urls: list,
+     chunk_size: int = 2_000,
+     embedding_fct=default_embedding_function,
+ ):
+     dl = DocumentsLoader()
+     docs = dl.load_docs(urls)
+     splits = dl.split_docs(docs, chunk_size=chunk_size)
+     contents = [split.page_content for split in splits]
+     metadatas = [split.metadata for split in splits]
+     # Collapse runs of newlines to clean the text a bit.
+     cleaned_contents = [re.sub(r"\n+", " ", content) for content in contents]
+     # Use the embedding function passed in, not the module-level default.
+     embeddings = embedding_fct(cleaned_contents)
+     chroma_collection = client.get_collection(name=kb_name)
+     ids = [secrets.token_hex(16) for _ in cleaned_contents]
+     chroma_collection.add(
+         documents=cleaned_contents, embeddings=embeddings, ids=ids, metadatas=metadatas
+     )
+     n_splits = chroma_collection.count()
+     return {"status": 200, "n_split": n_splits}
+
+
+ if __name__ == "__main__":
+     df = pd.read_csv("test_marcello.csv")
+
+     kb_name = "new_new_test"
+     urls = df.values.tolist()
+     # res = generate_knowledge_box_from_url(
+     #     client=client,
+     #     urls=urls,
+     #     kb_name=kb_name,
+     #     embedding_fct=default_embedding_function,
+     #     chunk_size=2_000,
+     # )
+
+     df = pd.read_csv("test2.csv")
+     urls = df.values.tolist()
+     res = add_links_to_knowledge_base(
+         client=client,
+         kb_name="test",
+         urls=urls,
+     )
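A minimal usage sketch, not part of the commit. Note that generate_kb.py imports `client` and `default_embedding_function` from a `main` module this commit does not add (app.py defines those names), so that import must resolve in the deployed repo; "demo_kb" below is a hypothetical collection name.

```python
# Hedged sketch: build a small knowledge base from one URL.
# Assumes the Chroma server from utils.py is reachable and OPENAI_API_KEY is set.
from utils import get_chroma_client, get_embedding_function
from generate_kb import generate_knowledge_box_from_url

client = get_chroma_client()
res = generate_knowledge_box_from_url(
    client=client,
    kb_name="demo_kb",  # hypothetical name
    urls=["https://en.wikipedia.org/wiki/Kobe_Bryant"],
    embedding_fct=get_embedding_function(),
    chunk_size=2_000,
)
print(res)  # {"status": 200, "n_split": <number of stored splits>}
```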
pages/create_knowledge_box.py ADDED
@@ -0,0 +1,34 @@
+ import streamlit as st
+ from main import client, default_embedding_function
+ import pandas as pd
+ from generate_kb import generate_knowledge_box_from_url
+
+ # Title of the app
+ st.title("Create a knowledge box from CSV file")
+
+ # File uploader widget
+ uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
+ df = None
+
+ if uploaded_file is not None:
+     try:
+         df = pd.read_csv(uploaded_file)
+         st.write("DataFrame:")
+         st.write(df)
+     except Exception as e:
+         st.error(str(e))
+
+
+ if uploaded_file is not None:
+     st.text("Don't use spaces in the new name; use underscores (_) instead.")
+     kb_name = st.text_input(label="new knowledge base name")
+     if st.button("Generate new knowledge box"):
+         urls = df.values.tolist()
+         res = generate_knowledge_box_from_url(
+             client=client,
+             urls=urls,
+             kb_name=kb_name,
+             embedding_fct=default_embedding_function,
+             chunk_size=2_000,
+         )
+         st.json(res)
pages/delete_knowledge_box⚠️.py ADDED
@@ -0,0 +1,16 @@
+ import streamlit as st
+ from retrieve_kb import get_current_knowledge_bases
+ from main import client
+
+
+ st.title("Delete knowledge Base ☠️")
+
+ st.title("Get knowledge boxes")
+ if st.button("Get current knowledge bases"):
+     kbs = get_current_knowledge_bases(client=client)
+     st.json(kbs)
+
+ collection_name = st.text_input(label="collection name")
+ if st.button("Delete Forever"):
+     client.delete_collection(collection_name)
+     st.success("Deleted")
pages/manage_knowledge_box.py ADDED
@@ -0,0 +1,77 @@
+ import streamlit as st
+ from retrieve_kb import get_current_knowledge_bases, get_knowledge_base_information
+ from generate_kb import add_links_to_knowledge_base
+ from main import client, default_embedding_function
+ import pandas as pd
+
+ st.title("Get knowledge boxes")
+
+ if st.button("Get current knowledge bases"):
+     kbs = get_current_knowledge_bases(client=client)
+     st.json(kbs)
+
+ collection_name = st.text_input(label="knowledge base name")
+ info = {}
+ collection = None
+
+ if "df" not in st.session_state:
+     st.session_state["df"] = pd.DataFrame()
+
+ col1, col2 = st.columns(2)
+
+ if st.button("Get All"):
+     collection_info, coll = get_knowledge_base_information(
+         client=client,
+         embedding_function=default_embedding_function,
+         kb_name=collection_name,
+     )
+     st.session_state["collection"] = coll
+     collection = coll
+     # st.write(collection_info)
+     df = pd.DataFrame.from_records(collection_info)
+     df["source"] = df["metadatas"].apply(lambda x: x.get("source", "unknown"))
+     df["title"] = df["metadatas"].apply(lambda x: x.get("title", "unknown"))
+     df = df[["documents", "source", "title", "ids"]]
+     st.session_state["df"] = df
+
+
+ if len(st.session_state["df"]) != 0:
+     st.dataframe(st.session_state["df"], width=3_000)
+     unique_df = st.session_state["df"]["source"].unique()
+     st.text(f"unique urls: {len(unique_df)}")
+     st.dataframe(unique_df)
+     st.header("Remove a split")
+     split_id = st.text_input("Insert a split id")
+     if st.button("Remove Id from collection"):
+         if split_id in st.session_state["df"]["ids"].values.tolist():
+             # Delete the chosen id (the original passed the literal string "id").
+             res = st.session_state["collection"].delete(ids=[split_id])
+             st.success(f"id {split_id} deleted")
+         else:
+             st.error(f"id {split_id} not in kb")
+
+ st.header("Add url to existing collection")
+ url_text = st.text_input("Insert a url link")
+ if st.button("add url to collection"):
+     urls = [url_text]  # put in a list even if only one
+     res = add_links_to_knowledge_base(client=client, kb_name=collection_name, urls=urls)
+     st.write(res)
+
+
+ st.header("Add csv to existing collection")
+ uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
+ df = None
+
+ if uploaded_file is not None:
+     try:
+         new_df = pd.read_csv(uploaded_file)
+         st.write("DataFrame:")
+         st.write(new_df)
+     except Exception as e:
+         st.error(str(e))
+     if st.button("add csv urls to collection"):
+         urls = new_df.values.tolist()
+         st.write(urls)
+         res = add_links_to_knowledge_base(
+             client=client, kb_name=collection_name, urls=urls
+         )
+         st.write(res)
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ tiktoken>=0.5.1
+ pysqlite3-binary
+ langchain>=0.1.6
+ #sqlite3>=3.35.0
+ chromadb>=0.4.22
+ openai==1.14.2
+ beautifulsoup4==4.12.2
+ pandas>=2.1.1
+ streamlit>=1.27.2
+ python-dotenv==1.0.0
+ fastapi>=0.104.0
+ uvicorn>=0.23.2
+ #pypdf==3.16.4
+ #python-multipart==0.0.6
+ #matplotlib==3.8.3
+ #umap-learn==0.5.5
retrieve_kb.py ADDED
@@ -0,0 +1,34 @@
+ from utils import get_chroma_client, get_embedding_function
+
+
+ default_embedding_function = get_embedding_function()
+
+
+ def get_current_knowledge_bases(client):
+     knowledge_boxes = client.list_collections()
+     return knowledge_boxes
+
+
+ def get_knowledge_base_information(
+     client, kb_name: str, embedding_function=default_embedding_function
+ ):
+     collection = client.get_collection(
+         name=kb_name, embedding_function=embedding_function
+     )
+
+     # "embeddings" can be added to the include list as well.
+     collection_info = collection.get(include=["documents", "metadatas"])
+
+     return collection_info, collection
+
+
+ if __name__ == "__main__":
+     client = get_chroma_client()
+     knowledge_boxes = get_current_knowledge_bases(client=client)
+     for kb in knowledge_boxes:
+         print(kb.name)
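A minimal query sketch, not part of the commit: once a knowledge base exists, the nearest splits for a question can be fetched straight from the collection. The collection name "test" (referenced in generate_kb.py) and the question are assumptions.

```python
# Hedged sketch: retrieve the two nearest splits for a query.
from utils import get_chroma_client, get_embedding_function

client = get_chroma_client()
collection = client.get_collection(
    name="test", embedding_function=get_embedding_function()
)
hits = collection.query(query_texts=["Who is Kobe Bryant?"], n_results=2)
print(hits["documents"])
```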
services/document_manager/document_loader.py ADDED
@@ -0,0 +1,83 @@
+ import re
+
+ import pandas as pd
+ import requests
+ from langchain.document_loaders import PyPDFLoader, WebBaseLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+
+ class DocumentsLoader:
+     def __init__(self) -> None:
+         pass
+
+     def load_urls_from_csv(self, url_path: str, column: str = "url"):
+         df = pd.read_csv(url_path)
+         doc_urls = df[column].to_list()
+         return doc_urls
+
+     def is_notion_url(self, url):
+         # Regular expression to match Notion URLs
+         notion_regex = r"https://(www\.)?(notion\.so|notion\.site)/"
+         # Check if the URL matches the Notion regex
+         return re.match(notion_regex, url) is not None
+
+     def is_pdf_url(self, url):
+         # Define a list of common PDF file extensions
+         pdf_extensions = [".pdf"]
+
+         # Check if the URL ends with a PDF file extension
+         for extension in pdf_extensions:
+             if url.endswith(extension):
+                 return True
+
+         return False
+
+     def is_valid_url(self, url):
+         # TODO: handle status codes other than 200
+         try:
+             response = requests.head(url)
+             # Return False for any non-200 status (the original fell through
+             # and implicitly returned None).
+             return response.status_code == 200
+         except requests.RequestException:
+             return False
+
+     def load_docs(self, doc_urls: list) -> list:
+         web_urls, pdf_urls, docs = [], [], []
+         # Flatten rows like [["url"], ...] coming from df.values.tolist().
+         if isinstance(doc_urls[0], list):
+             doc_urls = [doc[0] for doc in doc_urls]
+         for url in doc_urls:
+             if self.is_pdf_url(url):
+                 pdf_urls.append(url)
+             else:
+                 web_urls.append(url)
+
+         if len(web_urls) > 0:
+             web_urls = [url for url in web_urls if self.is_valid_url(url)]
+             for web_url in web_urls:
+                 try:
+                     web_loader = WebBaseLoader(web_url)
+                     web_docs = web_loader.load()
+                     docs = docs + web_docs
+                 except Exception as e:
+                     print(f"Error web loader, {web_url}: {str(e)}")
+
+         if len(pdf_urls) > 0:
+             pdf_urls = [url for url in pdf_urls if self.is_valid_url(url)]
+             for pdf_url in pdf_urls:
+                 try:
+                     pdf_loader = PyPDFLoader(pdf_url)
+                     pdf_docs = pdf_loader.load()
+                     docs = docs + pdf_docs
+                 except Exception as e:
+                     print(f"Error pdf loader, {pdf_url}: {str(e)}")
+         return docs
+
+     def split_docs(self, docs, chunk_size=2000):
+         # Separators are matched literally here; the regex-style "\. " from the
+         # original would never occur in plain text, so use ". " to split on
+         # sentence boundaries.
+         r_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=chunk_size,
+             chunk_overlap=0,
+             separators=["\n\n", "\n", ". ", " ", ""],
+         )
+         splits = r_splitter.split_documents(docs)
+         return splits
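A minimal usage sketch, not part of the commit: loading and splitting a single page with the class above (the URL is just an example; an unreachable URL is silently skipped by `is_valid_url`).

```python
# Hedged sketch: load one web page and split it into ~2000-character chunks.
from services.document_manager.document_loader import DocumentsLoader

dl = DocumentsLoader()
docs = dl.load_docs(["https://en.wikipedia.org/wiki/Naruto"])
splits = dl.split_docs(docs, chunk_size=2000)
print(len(splits), splits[0].page_content[:80])
```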
services/embedding_manager/embedding_manager.py ADDED
@@ -0,0 +1,19 @@
+ import numpy as np
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from utils import ModelName
+
+
+ class EmbeddingManager:
+     def __init__(self, model_name=ModelName.OPENAI) -> None:
+         self.model_name = model_name
+
+     def compare_embeddings_similarity(self, embedding_1, embedding_2):
+         similarity = np.dot(embedding_1, embedding_2)
+         return similarity
+
+     def generate_embeddings(self, splits: list[str]):
+         embedding = None
+         if self.model_name == ModelName.OPENAI:
+             embedding = OpenAIEmbeddings()
+         embeddings = [embedding.embed_query(split) for split in splits]
+         return embeddings
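A hedged aside, not part of the commit: the `np.dot` above equals cosine similarity only when both vectors are unit length, which holds for OpenAI embeddings (they are returned normalized). For other models an explicit normalization is safer:

```python
# Hedged sketch: cosine similarity that works for embeddings of any norm.
import numpy as np

def cosine_similarity(a, b):
    # Normalizes explicitly, so the result is cosine similarity for any model.
    a, b = np.asarray(a), np.asarray(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```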
services/vectordb_manager/vectordb_manager.py ADDED
@@ -0,0 +1,163 @@
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+ from langchain.embeddings import OpenAIEmbeddings
+
+ # from langchain_community.embeddings import OpenAIEmbeddings
+ from langchain.vectorstores import Chroma
+ from src.utils import brian_knows_system_message
+ from uuid import uuid4
+
+ import chromadb
+ from chromadb.config import Settings
+ from chromadb.utils import embedding_functions
+
+ import sys
+ import os
+ import openai
+
+ import logging
+
+ sys.path.append("../..")
+ from dotenv import load_dotenv, find_dotenv
+
+ _ = load_dotenv(find_dotenv())  # read local .env file
+ openai.api_key = os.environ["OPENAI_API_KEY"]
+
+
+ class VectordbManager:
+     def __init__(
+         self,
+         knowledge_base_name: str,
+     ) -> None:
+         self.knowledge_base_name = knowledge_base_name
+         self.vector_db = None
+
+     def load_vectordb(
+         self,
+         embedding_function=OpenAIEmbeddings(),
+     ):
+         client = chromadb.HttpClient(
+             host="chroma.brianknows.org",
+             port=443,
+             ssl=True,
+             settings=Settings(allow_reset=True),
+         )
+         vectordb = Chroma(embedding_function=embedding_function, client=client)
+         self.vector_db = vectordb
+
+     def load_collection(self, embedding_function=OpenAIEmbeddings()):
+         client = chromadb.HttpClient(
+             host="chroma.brianknows.org",
+             port=443,
+             ssl=True,
+             settings=Settings(
+                 allow_reset=True,
+             ),
+         )
+
+         collection = client.get_collection(
+             self.knowledge_base_name,
+             embedding_function=embedding_functions.OpenAIEmbeddingFunction(
+                 api_key=os.environ["OPENAI_API_KEY"]
+             ),
+         )
+         return collection
+
+     def create_vector_db(self, splits: list, knowledge_base_name: str):
+         logging.info("create_vector_db")
+         embedding_fn = OpenAIEmbeddings()
+
+         try:
+             client = chromadb.HttpClient(
+                 host="chroma.brianknows.org",
+                 port=443,
+                 ssl=True,
+                 settings=Settings(
+                     allow_reset=True,
+                 ),
+             )
+             collection = client.get_or_create_collection(
+                 knowledge_base_name,
+                 embedding_function=embedding_functions.OpenAIEmbeddingFunction(
+                     api_key=os.environ["OPENAI_API_KEY"]
+                 ),
+             )
+
+             ids = []
+             metadatas = []
+             documents = []
+
+             for split in splits:
+                 ids.append(str(uuid4()))
+                 metadatas.append(split.metadata)
+                 documents.append(split.page_content)
+             collection.add(documents=documents, ids=ids, metadatas=metadatas)
+             vector_db = Chroma.from_documents(
+                 documents=splits, embedding=embedding_fn, client=client
+             )
+             self.vector_db = vector_db
+
+         except Exception as e:
+             logging.error(f"error in creating db: {str(e)}")
+
+     def add_splits_to_existing_vectordb(
+         self,
+         splits: list,
+     ):
+         for split in splits:
+             try:
+                 self.vector_db.add_documents([split])
+                 print("document loaded!")
+             except Exception as e:
+                 print(f"Error with doc : {split}")
+                 print(e)
+
+     def retrieve_docs_from_query(self, query: str, k=2, fetch_k=3) -> list:
+         """
+         query: text to look up documents similar to.
+         k: number of documents to return (defaults to 2 here).
+         fetch_k: number of documents fetched and passed to the MMR algorithm,
+             whose lambda_mult (0 = max diversity, 1 = min diversity) defaults to 0.5.
+         """
+         retrieved_docs = self.vector_db.max_marginal_relevance_search(
+             query, k=k, fetch_k=fetch_k
+         )
+         return retrieved_docs
+
+     def retrieve_qa(
+         self,
+         llm,
+         query: str,
+         score_threshold: float = 0.65,
+         system_message=brian_knows_system_message,
+     ):
+         """Return an LLM answer grounded in the retrieved docs."""
+
+         # Build prompt. (A legacy inline prompt, superseded by `system_message`:
+         # "You are a Web3 assistant. Use the following pieces of context to
+         # answer the question at the end. If you don't know the answer, just
+         # say: 'I don't know'. Don't try to make up an answer! Always provide
+         # a detailed and comprehensive response.")
+         fixed_template = """ {context}
+         Question: {question}
+         Detailed Answer:"""
+
+         template = system_message + fixed_template
+
+         QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
+
+         # Run chain
+         qa_chain = RetrievalQA.from_chain_type(
+             llm,
+             retriever=self.vector_db.as_retriever(
+                 search_type="similarity_score_threshold",
+                 search_kwargs={"score_threshold": score_threshold},
+             ),
+             return_source_documents=True,
+             chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
+             # reduce_k_below_max_tokens=True,
+         )
+         result = qa_chain({"query": query})
+
+         return result
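A minimal usage sketch, not part of the commit: question answering over an existing knowledge base. It assumes `OPENAI_API_KEY` is set, the named collection exists, and the module's own `from src.utils import brian_knows_system_message` import resolves; the collection name and question are hypothetical.

```python
# Hedged sketch: QA over an existing collection with the manager above.
from langchain.chat_models import ChatOpenAI
from services.vectordb_manager.vectordb_manager import VectordbManager

manager = VectordbManager(knowledge_base_name="test")  # hypothetical name
manager.load_vectordb()
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
result = manager.retrieve_qa(llm=llm, query="Who created Dragon Ball?")
print(result["result"])            # the answer text
print(result["source_documents"])  # the splits it was grounded in
```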
test2.csv ADDED
@@ -0,0 +1,3 @@
+ url
+ https://it.wikipedia.org/wiki/Michael_Jordan
+ https://en.wikipedia.org/wiki/Kobe_Bryant
test_marcello.csv ADDED
@@ -0,0 +1,3 @@
+ url
+ https://en.wikipedia.org/wiki/Dragon_Ball
+ https://en.wikipedia.org/wiki/Naruto
utils.py ADDED
@@ -0,0 +1,31 @@
+ import chromadb
+ from chromadb.config import Settings
+ import chromadb.utils.embedding_functions as embedding_functions
+ from dotenv import load_dotenv
+ import os
+
+ load_dotenv()
+ openai_key = os.getenv("OPENAI_API_KEY")
+
+
+ def get_chroma_client(
+     host: str = "chroma.brianknows.org",
+     port: int = 443,
+ ):
+     # The original passed ssl=port and hard-coded port=443; pass the port
+     # argument through and enable TLS explicitly instead.
+     chroma_client = chromadb.HttpClient(
+         host=host,
+         port=port,
+         ssl=True,
+         settings=Settings(
+             allow_reset=True,
+         ),
+     )
+
+     return chroma_client
+
+
+ def get_embedding_function(model_name="text-embedding-ada-002"):
+     openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+         api_key=openai_key, model_name=model_name
+     )
+     return openai_ef
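A minimal usage sketch, not part of the commit: these two helpers are the shared entry points used by app.py and retrieve_kb.py, and assume `OPENAI_API_KEY` is set in the environment.

```python
# Hedged sketch: connect to the hosted Chroma server and list collections.
from utils import get_chroma_client, get_embedding_function

client = get_chroma_client()   # talks to chroma.brianknows.org over TLS
ef = get_embedding_function()  # OpenAI text-embedding-ada-002 by default
print(client.list_collections())
```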