# Zeta / app.py
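"""Streamlit RAG chatbot: load PDF and web documents, split them with a
semantic chunker, index them in a local Chroma vector store, and answer
questions through a conversational retrieval chain with contextual
compression.

Run with: streamlit run app.py
"""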
import os
from pathlib import Path
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Chroma
from langchain.llms.openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_experimental.text_splitter import SemanticChunker
import streamlit as st
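
# Directory on disk where the Chroma index is persisted between runs.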
LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")
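
# Fetch each source document: PDFs go through PyPDFLoader, everything else
# is treated as a web page and scraped with WebBaseLoader.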
def load_documents():
loaders = [
PyPDFLoader(source_doc_url)
if source_doc_url.endswith(".pdf")
else WebBaseLoader(source_doc_url)
for source_doc_url in st.session_state.source_doc_urls
]
documents = []
for loader in loaders:
documents.extend(loader.load())
return documents
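
# Split documents on semantic boundaries (embedding-distance breakpoints
# between sentences) rather than at a fixed character count.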
def split_documents(documents):
text_splitter = SemanticChunker(OpenAIEmbeddings())
texts = text_splitter.split_documents(documents)
return texts
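
# Index the chunks in a persistent local Chroma store, then wrap the
# retriever in an LLM-based compressor that extracts only the passages
# relevant to the query. MMR search diversifies the top-k results.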
def embeddings_on_local_vectordb(texts):
vectordb = Chroma.from_documents(
texts,
embedding=OpenAIEmbeddings(),
persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
)
vectordb.persist()
retriever = ContextualCompressionRetriever(
base_compressor=LLMChainExtractor.from_llm(OpenAI(temperature=0)),
base_retriever=vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr"),
)
return retriever
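
# Answer a query with a conversational retrieval chain, keeping the chat
# history in Streamlit session state and returning the source documents
# the answer was grounded in.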
def query_llm(retriever, query):
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(),
        retriever=retriever,
        return_source_documents=True,
    )
    # The chain retrieves (and compresses) the relevant documents itself, so
    # reuse its source_documents instead of querying the retriever twice.
    result = qa_chain({"question": query, "chat_history": st.session_state.messages})
    relevant_docs = result["source_documents"]
    answer = result["answer"]
    st.session_state.messages.append((query, answer))
    return relevant_docs, answer
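
# Collect user-supplied settings from the sidebar.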
def input_fields():
    # Read the key from a masked sidebar field instead of hardcoding a secret.
    os.environ["OPENAI_API_KEY"] = st.sidebar.text_input("OpenAI API Key", type="password")
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_input("Source Document URLs").split(",")
        if url.strip()
    ]
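
# Load, split, and index the configured documents; store the resulting
# retriever in session state so it survives Streamlit reruns.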
def process_documents():
    if not st.session_state.source_doc_urls:
        st.sidebar.error("Please provide at least one source document URL.")
        return
    try:
        documents = load_documents()
        texts = split_documents(documents)
        st.session_state.retriever = embeddings_on_local_vectordb(texts)
    except Exception as e:
        st.error(f"An error occurred: {e}")
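
# Main entry point: renders the UI, replays the chat history, and routes
# new queries through the retriever.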
def boot():
st.title("Enigma Chatbot")
input_fields()
st.sidebar.button("Submit Documents", on_click=process_documents)
st.sidebar.write("---")
st.sidebar.write("References made during the chat will appear here")
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
st.chat_message("human").write(message[0])
st.chat_message("ai").write(message[1])
    if query := st.chat_input():
        st.chat_message("human").write(query)
        if "retriever" not in st.session_state:
            st.error("Please submit source documents before asking questions.")
            st.stop()
        references, response = query_llm(st.session_state.retriever, query)
        for doc in references:
            # Web pages carry no "page" metadata, so fall back to the source URL.
            ref = doc.metadata.get("page", doc.metadata.get("source", "unknown"))
            st.sidebar.info(f"Reference: {ref}\n\n{doc.page_content}")
        st.chat_message("ai").write(response)

if __name__ == "__main__":
boot()