|
import os
import tempfile
import traceback
from pathlib import Path

from dotenv import load_dotenv
import streamlit as st

from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
|
|
|
load_dotenv() |
|
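# Local scratch space for uploaded PDFs and the persisted Chroma index.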
TMP_DIR = Path(__file__).resolve().parent.joinpath('data', 'tmp')
LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath('data', 'vector_store')

os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(LOCAL_VECTOR_STORE_DIR, exist_ok=True)
|
st.set_page_config(page_title="RAG")
st.title("Retrieval Augmented Generation Engine")

# Seed the session with a key from the environment; the sidebar input can override it.
openai_api_key = os.environ.get('OPENAI_API_KEY')
st.session_state.openai_api_key = openai_api_key
|
|
|
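# Load every PDF staged under TMP_DIR into LangChain Document objects.
# DirectoryLoader's default UnstructuredFileLoader handles the PDF parsing.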
def load_documents():
    loader = DirectoryLoader(TMP_DIR.as_posix(), glob='**/*.pdf')
    documents = loader.load()
    return documents
|
|
|
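# Split documents into ~1000-character chunks for embedding and retrieval.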
def split_documents(documents):
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    return texts
|
|
|
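# Open (or create) the persisted Chroma collection and wrap it as a retriever.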
def embeddings_on_local_vectordb(): |
|
|
|
|
|
vectordb=Chroma(persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(), embedding_function=OpenAIEmbeddings()) |
|
vectordb.persist() |
|
retriever = vectordb.as_retriever(search_kwargs={'k': 5}) |
|
return retriever |


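# Run one conversational retrieval turn and record the (question, answer) pair.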
def query_llm(retriever, query):
    try:
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=ChatOpenAI(temperature=0, openai_api_key=st.session_state.openai_api_key),
            retriever=retriever,
            return_source_documents=True,
        )
        # The chain expects chat_history as a list of (human, ai) tuples,
        # which matches how st.session_state.messages is stored.
        result = qa_chain({'question': query, 'chat_history': st.session_state.messages})
        result = result.get('answer')
    except Exception as e:
        print(f"Exception {e} with traceback: {traceback.format_exc()} occurred for API key: {st.session_state.openai_api_key}")
        result = ""
    st.session_state.messages.append((query, result))
    return result
|
|
|
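# Collect the OpenAI key in the sidebar and the PDF uploads in the main pane.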
def input_fields():
    with st.sidebar:
        openai_key = st.text_input("OpenAI API key", type="password")
        if openai_key != "":
            st.session_state.openai_api_key = openai_key

    st.session_state.source_docs = st.file_uploader(label="Upload Documents", type="pdf", accept_multiple_files=True)
|
|
|
|
|
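# Build the retriever once at startup; process_documents() adds new chunks
# to the same underlying Chroma store.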
retriever = embeddings_on_local_vectordb() |
|
|
|
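# Stage uploads to disk, load and chunk them, then index the chunks in Chroma.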
def process_documents():
    if not st.session_state.openai_api_key or not st.session_state.source_docs:
        st.warning("Please upload the documents and provide the missing fields.")
    else:
        try:
            # Write each uploaded file into TMP_DIR so DirectoryLoader can read it.
            for source_doc in st.session_state.source_docs:
                with tempfile.NamedTemporaryFile(delete=False, dir=TMP_DIR.as_posix(), suffix='.pdf') as tmp_file:
                    tmp_file.write(source_doc.read())

            documents = load_documents()

            # Remove the staged files once they have been loaded.
            for _file in TMP_DIR.iterdir():
                _file.unlink()

            texts = split_documents(documents)

            print(f"Adding {len(texts)} texts to vector DB")
            # The retriever only wraps the vector store; the chunked Document
            # objects must be added to the store itself and then persisted.
            retriever.vectorstore.add_documents(texts)
            retriever.vectorstore.persist()
        except Exception as e:
            st.error(f"An error occurred: {e}")
|
|
|
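# Streamlit entry point: render inputs, replay chat history, answer new queries.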
def boot():
    input_fields()

    st.button("Submit Documents", on_click=process_documents)

    if "messages" not in st.session_state:
        st.session_state.messages = []

    for message in st.session_state.messages:
        st.chat_message('human').write(message[0])
        st.chat_message('ai').write(message[1])

    if query := st.chat_input():
        st.chat_message("human").write(query)
        response = query_llm(retriever, query)
        st.chat_message("ai").write(response)
|
|
|
if __name__ == '__main__':
    boot()
|
|