from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.embeddings import OpenAIEmbeddings
# from langchain_community.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from src.utils import brian_knows_system_message
from uuid import uuid4
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
import sys
import os
import openai
import streamlit as st
import logging

sys.path.append("../..")

from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())  # read local .env file

# The OpenAI API key is read from Streamlit secrets; reading it from the
# environment is left commented out as an alternative.
# openai.api_key = os.environ["OPENAI_API_KEY"]
openai_key = st.secrets["OPENAI_API_KEY"]
openai.api_key = openai_key


class VectordbManager:
    """Manage a Chroma-backed vector store for a named knowledge base."""

    def __init__(
        self,
        knowledge_base_name: str,
    ) -> None:
        self.knowledge_base_name = knowledge_base_name
        self.vector_db = None
    def load_vectordb(
        self,
        embedding_function=OpenAIEmbeddings(),
    ):
        """Connect to the remote Chroma server and keep a LangChain wrapper on self."""
        client = chromadb.HttpClient(
            host="chroma.brianknows.org",
            port=443,
            ssl=True,
            settings=Settings(allow_reset=True),
        )
        # NOTE: no collection_name is passed, so the wrapper uses LangChain's
        # default Chroma collection rather than self.knowledge_base_name.
        vectordb = Chroma(embedding_function=embedding_function, client=client)
        self.vector_db = vectordb
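
    # Hedged usage sketch (the knowledge base name and query are illustrative):
    #     manager = VectordbManager("my_kb")
    #     manager.load_vectordb()
    #     docs = manager.retrieve_docs_from_query("what is a rollup?")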
    def load_collection(self, embedding_function=OpenAIEmbeddings()):
        """Fetch the raw chromadb collection for this knowledge base.

        NOTE: the `embedding_function` argument is currently unused; the
        chromadb-native OpenAI embedding function is used instead.
        """
        client = chromadb.HttpClient(
            host="chroma.brianknows.org",
            port=443,
            ssl=True,
            settings=Settings(
                allow_reset=True,
            ),
        )
        collection = client.get_collection(
            self.knowledge_base_name,
            embedding_function=embedding_functions.OpenAIEmbeddingFunction(
                api_key=openai_key
            ),
        )
        return collection
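
    # Hedged sketch: the returned chromadb collection can also be queried
    # directly, bypassing LangChain (query text is illustrative):
    #     collection = manager.load_collection()
    #     hits = collection.query(query_texts=["what is a rollup?"], n_results=2)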
    def create_vector_db(self, splits: list, knowledge_base_name: str):
        """Create (or reuse) a remote collection and index the given splits."""
        logging.info("create_vector_db")
        embedding_fn = OpenAIEmbeddings()
        try:
            client = chromadb.HttpClient(
                host="chroma.brianknows.org",
                port=443,
                ssl=True,
                settings=Settings(
                    allow_reset=True,
                ),
            )
            collection = client.get_or_create_collection(
                knowledge_base_name,
                embedding_function=embedding_functions.OpenAIEmbeddingFunction(
                    api_key=openai_key
                ),
            )
            # Add every split to the named collection under a fresh UUID.
            ids = []
            metadatas = []
            documents = []
            for split in splits:
                ids.append(str(uuid4()))
                metadatas.append(split.metadata)
                documents.append(split.page_content)
            collection.add(documents=documents, ids=ids, metadatas=metadatas)
            # Also build a LangChain wrapper over the same client; note that
            # Chroma.from_documents writes to LangChain's default collection,
            # not to `knowledge_base_name`, so the splits are indexed twice.
            vector_db = Chroma.from_documents(
                documents=splits, embedding=embedding_fn, client=client
            )
            self.vector_db = vector_db
        except Exception as e:
            logging.error(f"error in creating db: {str(e)}")
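
    # A minimal sketch of producing `splits` (assumes the same LangChain
    # version as the imports above; names are illustrative):
    #     from langchain.text_splitter import RecursiveCharacterTextSplitter
    #     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    #     splits = splitter.split_documents(raw_docs)
    #     manager.create_vector_db(splits, "my_kb")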
    def add_splits_to_existing_vectordb(
        self,
        splits: list,
    ):
        """Append document splits one by one to the loaded vector store."""
        for split in splits:
            try:
                self.vector_db.add_documents([split])
                print("document loaded!")
            except Exception as e:
                print(f"Error with doc: {split}")
                print(e)
    def retrieve_docs_from_query(self, query: str, k=2, fetch_k=3) -> list:
        """Run a max-marginal-relevance search against the vector store.

        query : text to look up similar documents for.
        k : number of documents to return. Defaults to 2.
        fetch_k : number of documents to fetch and pass to the MMR algorithm.
            Defaults to 3.
        lambda_mult (library default, not exposed here) : number between 0 and 1
            controlling diversity among the results, with 0 for maximum diversity
            and 1 for minimum. Defaults to 0.5.
        """
        retrieved_docs = self.vector_db.max_marginal_relevance_search(
            query, k=k, fetch_k=fetch_k
        )
        return retrieved_docs
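
    # Hedged example call (query text is illustrative):
    #     docs = manager.retrieve_docs_from_query("how do optimistic rollups work?", k=2)
    #     for doc in docs:
    #         print(doc.metadata, doc.page_content[:80])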
    def retrieve_qa(
        self,
        llm,
        query: str,
        score_threshold: float = 0.65,
        system_message=brian_knows_system_message,
    ):
        """Return the LLM answer to `query`, grounded in retrieved docs."""
        # Build the prompt: the system message supplies the instructions
        # (e.g. "You are a Web3 assistant. Use the following pieces of context
        # to answer the question at the end. If you don't know the answer,
        # just say: 'I don't know'. Don't try to make up an answer! Always
        # provide a detailed and comprehensive response."), followed by the
        # retrieved context and the question.
        fixed_template = """ {context}
        Question: {question}
        Detailed Answer:"""
        template = system_message + fixed_template
        QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
        # Run the retrieval-QA chain; the similarity-score cutoff on the
        # retriever drops low-relevance chunks.
        qa_chain = RetrievalQA.from_chain_type(
            llm,
            retriever=self.vector_db.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={"score_threshold": score_threshold},
            ),
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
            # reduce_k_below_max_tokens=True,
        )
        result = qa_chain({"query": query})
        return result
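

if __name__ == "__main__":
    # A minimal usage sketch, not part of the app: the knowledge base name
    # and query are illustrative, and ChatOpenAI assumes the same LangChain
    # version as the imports above.
    from langchain.chat_models import ChatOpenAI

    manager = VectordbManager("my_kb")
    manager.load_vectordb()
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
    answer = manager.retrieve_qa(llm, "What is a smart contract?")
    print(answer["result"])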