|
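"""Enigma Chatbot: a Streamlit RAG app.

Loads PDF and web documents, splits them into semantic chunks, embeds them
into a local Chroma vector store, and answers questions through a
conversational retrieval chain with LLM-based contextual compression.

Run (assuming this file is saved as app.py): streamlit run app.py

Approximate dependencies (unpinned; this code targets the pre-0.1 LangChain
API): langchain, langchain-experimental, chromadb, openai, streamlit, pypdf,
beautifulsoup4.
"""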
import os
from pathlib import Path

from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.vectorstores import Chroma
from langchain_experimental.text_splitter import SemanticChunker

import streamlit as st

# Directory where the Chroma vector store is persisted between runs.
LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")
|
|
def load_documents():
    """Load each source URL: PyPDFLoader for PDFs, WebBaseLoader otherwise."""
    loaders = [
        PyPDFLoader(source_doc_url)
        if source_doc_url.endswith(".pdf")
        else WebBaseLoader(source_doc_url)
        for source_doc_url in st.session_state.source_doc_urls
    ]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    return documents
|
|
def split_documents(documents):
    """Split documents into semantically coherent chunks."""
    text_splitter = SemanticChunker(OpenAIEmbeddings())
    texts = text_splitter.split_documents(documents)
    return texts
|
|
def embeddings_on_local_vectordb(texts):
    """Embed chunks into a persistent local Chroma store and return a
    compression retriever that extracts only query-relevant passages."""
    vectordb = Chroma.from_documents(
        texts,
        embedding=OpenAIEmbeddings(),
        persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
    )
    vectordb.persist()
    retriever = ContextualCompressionRetriever(
        base_compressor=LLMChainExtractor.from_llm(OpenAI(temperature=0)),
        base_retriever=vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr"),
    )
    return retriever
|
|
def query_llm(retriever, query):
    """Answer a query with conversational RAG and return the supporting
    source documents alongside the answer."""
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(),
        retriever=retriever,
        return_source_documents=True,
    )
    result = qa_chain({"question": query, "chat_history": st.session_state.messages})
    answer = result["answer"]
    # The chain already retrieved the documents it used (return_source_documents=True),
    # so reuse them rather than issuing a second, possibly inconsistent retrieval.
    relevant_docs = result["source_documents"]
    st.session_state.messages.append((query, answer))
    return relevant_docs, answer
|
|
def input_fields():
    """Collect the OpenAI API key and source document URLs from the sidebar."""
    # Never hardcode API keys in source; read the key from the user instead.
    openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
    if openai_api_key:
        os.environ["OPENAI_API_KEY"] = openai_api_key
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_input("Source Document URLs (comma-separated)").split(",")
        if url.strip()
    ]
|
|
def process_documents():
    """Load, split, and index the source documents, storing the retriever in session state."""
    try:
        documents = load_documents()
        texts = split_documents(documents)
        st.session_state.retriever = embeddings_on_local_vectordb(texts)
    except Exception as e:
        st.error(f"An error occurred: {e}")
|
|
def boot():
    """Render the chat UI and wire user input to the retrieval chain."""
    st.title("Enigma Chatbot")
    input_fields()
    st.sidebar.button("Submit Documents", on_click=process_documents)
    st.sidebar.write("---")
    st.sidebar.write("References made during the chat will appear here")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(message[1])
    if query := st.chat_input():
        st.chat_message("human").write(query)
        # Guard against querying before any documents have been indexed.
        if "retriever" not in st.session_state:
            st.error("Please submit the source documents first.")
            return
        references, response = query_llm(st.session_state.retriever, query)
        for doc in references:
            # Web pages carry no "page" metadata, so fall back gracefully.
            st.sidebar.info(f"Page {doc.metadata.get('page', 'n/a')}\n\n{doc.page_content}")
        st.chat_message("ai").write(response)
|
|
if __name__ == "__main__":
    boot()
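# A hypothetical example of the sidebar URL input, assuming one PDF and one
# web page (any URL ending in ".pdf" is routed to PyPDFLoader, everything
# else to WebBaseLoader):
#   https://example.com/paper.pdf, https://example.com/blog-post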
|
|