import logging
import os
import sys
from uuid import uuid4

import chromadb
import openai
import streamlit as st
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
# from langchain_community.embeddings import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from src.utils import brian_knows_system_message

sys.path.append("../..")

from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())  # read local .env file

# openai.api_key = os.environ["OPENAI_API_KEY"]
openai.api_key = st.secrets["OPENAI_API_KEY"]
openai_key = st.secrets["OPENAI_API_KEY"]
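# NOTE: st.secrets expects an OPENAI_API_KEY entry in .streamlit/secrets.toml
# (or in the hosting platform's secrets settings); the commented line above
# shows the plain environment-variable alternative.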

class VectordbManager:
    """Manage a named knowledge base stored on a remote Chroma server."""

    def __init__(
        self,
        knowledge_base_name: str,
    ) -> None:
        self.knowledge_base_name = knowledge_base_name
        self.vector_db = None
    def load_vectordb(
        self,
        embedding_function=None,
    ):
        """Wrap this knowledge base's collection as a LangChain vector store."""
        # Instantiate the default lazily rather than in the signature, so the
        # embeddings client is not created at import time.
        if embedding_function is None:
            embedding_function = OpenAIEmbeddings()
        client = chromadb.HttpClient(
            host="chroma.brianknows.org",
            port=443,
            ssl=True,
            settings=Settings(allow_reset=True),
        )
        # Point the wrapper at the named collection rather than Chroma's
        # default "langchain" collection.
        vectordb = Chroma(
            collection_name=self.knowledge_base_name,
            embedding_function=embedding_function,
            client=client,
        )
        self.vector_db = vectordb
    def load_collection(self):
        """Return the raw chromadb collection for this knowledge base."""
        client = chromadb.HttpClient(
            host="chroma.brianknows.org",
            port=443,
            ssl=True,
            settings=Settings(
                allow_reset=True,
            ),
        )
        collection = client.get_collection(
            self.knowledge_base_name,
            embedding_function=embedding_functions.OpenAIEmbeddingFunction(
                api_key=openai_key
            ),
        )
        return collection
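    # Illustrative usage (the knowledge-base name and query are placeholders,
    # not from this module):
    #
    #   collection = VectordbManager("docs").load_collection()
    #   hits = collection.query(query_texts=["what is a rollup?"], n_results=3)
    #   print(hits["documents"])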
    def create_vector_db(self, splits: list, knowledge_base_name: str):
        """Embed `splits` into a new or existing collection and wrap it as a
        LangChain vector store."""
        logging.info("create_vector_db")
        embedding_fn = OpenAIEmbeddings()
        try:
            client = chromadb.HttpClient(
                host="chroma.brianknows.org",
                port=443,
                ssl=True,
                settings=Settings(
                    allow_reset=True,
                ),
            )
            collection = client.get_or_create_collection(
                knowledge_base_name,
                embedding_function=embedding_functions.OpenAIEmbeddingFunction(
                    api_key=openai_key
                ),
            )
            ids = []
            metadatas = []
            documents = []
            for split in splits:
                ids.append(str(uuid4()))
                metadatas.append(split.metadata)
                documents.append(split.page_content)
            collection.add(documents=documents, ids=ids, metadatas=metadatas)
            # Wrap the collection we just populated; calling
            # Chroma.from_documents here would embed and store every split a
            # second time.
            vector_db = Chroma(
                collection_name=knowledge_base_name,
                embedding_function=embedding_fn,
                client=client,
            )
            self.vector_db = vector_db
        except Exception as e:
            logging.error(f"error in creating db: {str(e)}")
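    # Illustrative sketch of producing `splits` with LangChain's splitter
    # (the URL, chunk sizes, and knowledge-base name are placeholder
    # assumptions):
    #
    #   from langchain.document_loaders import WebBaseLoader
    #   from langchain.text_splitter import RecursiveCharacterTextSplitter
    #
    #   docs = WebBaseLoader("https://example.com/page").load()
    #   splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    #   splits = splitter.split_documents(docs)
    #   VectordbManager("docs").create_vector_db(splits, "docs")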
    def add_splits_to_existing_vectordb(
        self,
        splits: list,
    ):
        """Add new document splits to the already-loaded vector store."""
        for split in splits:
            try:
                self.vector_db.add_documents([split])
                print("document loaded!")
            except Exception as e:
                print(f"Error with doc: {split}")
                print(e)
    def retrieve_docs_from_query(self, query: str, k=2, fetch_k=3) -> list:
        """
        Retrieve documents with maximal marginal relevance (MMR) search:
        fetch_k candidates are pulled by similarity, then k of them are
        selected to be both relevant and mutually diverse.

        query : Text to look up documents similar to.
        k : Number of documents to return. Defaults to 2.
        fetch_k : Number of documents to fetch to pass to the MMR algorithm.
            Defaults to 3.
        (lambda_mult, the 0-1 diversity knob where 0 means maximum diversity,
        is not exposed here; LangChain defaults it to 0.5.)
        """
        retrieved_docs = self.vector_db.max_marginal_relevance_search(
            query, k=k, fetch_k=fetch_k
        )
        return retrieved_docs
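    # Illustrative call (the query text is a placeholder):
    #
    #   docs = manager.retrieve_docs_from_query("layer 2 scaling", k=2, fetch_k=5)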
    def retrieve_qa(
        self,
        llm,
        query: str,
        score_threshold: float = 0.65,
        system_message=brian_knows_system_message,
    ):
        """Return the LLM's answer to `query`, grounded in retrieved docs."""
        # Build the prompt: the system message carries the assistant persona
        # and instructions; the fixed part injects the retrieved context and
        # the question.
        fixed_template = """ {context}
        Question: {question}
        Detailed Answer:"""
        template = system_message + fixed_template
        QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
        # Run chain: only documents above the similarity score threshold are
        # passed to the LLM as context.
        qa_chain = RetrievalQA.from_chain_type(
            llm,
            retriever=self.vector_db.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={"score_threshold": score_threshold},
            ),
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
            # reduce_k_below_max_tokens=True,
        )
        result = qa_chain({"query": query})
        return result
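

# --- Usage sketch ------------------------------------------------------------
# A minimal end-to-end example, assuming an existing knowledge base named
# "docs" on the Chroma server and the legacy `langchain` ChatOpenAI wrapper;
# the knowledge-base name and the question are placeholders.
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI

    manager = VectordbManager(knowledge_base_name="docs")
    manager.load_vectordb()

    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
    answer = manager.retrieve_qa(llm, "What is a smart contract?")
    print(answer["result"])
    for doc in answer["source_documents"]:
        print(doc.metadata)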