itsmendas committed on
Commit
25db1d7
·
verified ·
1 Parent(s): 7ac5f1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -0
app.py CHANGED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, tempfile
2
+ # import pinecone
3
+ from pathlib import Path
4
+ import traceback
5
+ from langchain.chains import RetrievalQA, ConversationalRetrievalChain
6
+ from langchain.embeddings import OpenAIEmbeddings
7
+ from langchain.vectorstores import Chroma
8
+ from langchain import OpenAI
9
+ from langchain.chat_models import ChatOpenAI
10
+ from langchain.document_loaders import DirectoryLoader
11
+ from langchain.text_splitter import CharacterTextSplitter
12
+ from langchain.vectorstores import Chroma
13
+ from langchain.embeddings.openai import OpenAIEmbeddings
14
+ from langchain.memory import ConversationBufferMemory
15
+ from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
16
+ from dotenv import load_dotenv
17
+ import streamlit as st
18
+
19
# Load environment variables (e.g. OPENAI_API_KEY) before anything reads os.environ.
load_dotenv()

# Working directories, resolved relative to this file.
TMP_DIR = Path(__file__).resolve().parent.joinpath('data', 'tmp')
LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath('data', 'vector_store')

# Make sure both directories exist (no-op when they are already present).
os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(LOCAL_VECTOR_STORE_DIR, exist_ok=True)

st.set_page_config(page_title="RAG")
st.title("Retrieval Augmented Generation Engine")

# Seed the session with the key from the environment; input_fields() replaces
# it when the user types a non-empty key into the sidebar.
openai_api_key = os.environ.get('OPENAI_API_KEY')
st.session_state.openai_api_key = openai_api_key
38
+
39
def load_documents():
    """Load the PDFs currently staged in ``TMP_DIR``.

    Returns:
        Whatever ``DirectoryLoader.load()`` yields for the files under
        ``TMP_DIR`` that match the ``**/*.pdf`` glob.
    """
    return DirectoryLoader(TMP_DIR.as_posix(), glob='**/*.pdf').load()
43
+
44
def split_documents(documents):
    """Split loaded documents into chunks.

    Args:
        documents: Documents to split, e.g. the result of ``load_documents()``.

    Returns:
        The chunks produced by a ``CharacterTextSplitter`` configured with
        ``chunk_size=1000`` and ``chunk_overlap=0``.
    """
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    return splitter.split_documents(documents)
48
+
49
def embeddings_on_local_vectordb():
    """Open the Chroma store in ``LOCAL_VECTOR_STORE_DIR`` and wrap it in a retriever.

    The store is opened (not rebuilt from documents) with a default
    ``OpenAIEmbeddings()`` embedding function and persisted once.

    Returns:
        A retriever over the store, created with ``search_kwargs={'k': 5}``.
    """
    store_dir = LOCAL_VECTOR_STORE_DIR.as_posix()
    store = Chroma(persist_directory=store_dir, embedding_function=OpenAIEmbeddings())
    store.persist()
    return store.as_retriever(search_kwargs={'k': 5})
56
+
57
+ # def embeddings_on_pinecone(texts):
58
+ # pinecone.init(api_key=st.session_state.pinecone_api_key, environment=st.session_state.pinecone_env)
59
+ # embeddings = OpenAIEmbeddings(openai_api_key=st.session_state.openai_api_key)
60
+ # vectordb = Pinecone.from_documents(texts, embeddings, index_name=st.session_state.pinecone_index)
61
+ # retriever = vectordb.as_retriever()
62
+ # return retriever
63
+
64
def query_llm(retriever, query):
    """Answer ``query`` with a conversational retrieval chain over ``retriever``.

    The turns stored in ``st.session_state.messages`` are passed to the chain
    as chat history, and the new ``(query, answer)`` pair is appended to that
    list before returning.

    Args:
        retriever: Retriever handed to ``ConversationalRetrievalChain.from_llm``.
        query: The user's question.

    Returns:
        The chain's ``'answer'`` value, or ``""`` if building or running the
        chain raised.
    """
    try:
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=ChatOpenAI(temperature=0, openai_api_key=st.session_state.openai_api_key),
            retriever=retriever,
            return_source_documents=True,
        )
        result = qa_chain({'question': query, 'chat_history': st.session_state.messages})
        result = result.get('answer')
    except Exception as e:
        # Best-effort: log the failure and fall back to an empty answer.
        # The API key is a secret and is deliberately kept out of the log line.
        print(f"Exception {e} with traceback : {traceback.format_exc()} occurred")
        result = ""
    st.session_state.messages.append((query, result))
    return result
78
+
79
def input_fields():
    """Render the input widgets and store their values in session state.

    A non-empty key typed into the sidebar replaces
    ``st.session_state.openai_api_key``; the uploaded PDF files are stored in
    ``st.session_state.source_docs``.
    """
    with st.sidebar:
        entered_key = st.text_input("OpenAI API key", type="password")
        if not entered_key == "":
            st.session_state.openai_api_key = entered_key
        # Pinecone inputs (API key, environment, index name, on/off toggle)
        # are currently disabled; only the local vector store is used.

    # NOTE(review): the flattened source does not show whether the uploader
    # belongs inside the sidebar block above - confirm its placement.
    st.session_state.source_docs = st.file_uploader(
        label="Upload Documents",
        type="pdf",
        accept_multiple_files=True,
    )
106
+
107
# Retriever over the persisted local Chroma store. It is created when the
# script runs and is used by both process_documents() and boot().
retriever = embeddings_on_local_vectordb()


def process_documents():
    """Index the uploaded PDFs into the local vector store.

    Reads ``st.session_state.source_docs``, stages each upload as a temporary
    ``.pdf`` file in ``TMP_DIR``, loads and splits the staged files, and adds
    the resulting chunks to the store behind the module-level ``retriever``.

    Shows a warning when the API key or the uploads are missing, and an error
    message when any step raises; nothing is re-raised to the caller.
    """
    if not st.session_state.openai_api_key or not st.session_state.source_docs:
        st.warning("Please upload the documents and provide the missing fields.")
        return

    try:
        try:
            for source_doc in st.session_state.source_docs:
                # delete=False: the loader must be able to reopen the file by path.
                with tempfile.NamedTemporaryFile(delete=False, dir=TMP_DIR.as_posix(), suffix='.pdf') as tmp_file:
                    tmp_file.write(source_doc.read())
            documents = load_documents()
        finally:
            # Always empty the staging directory, even when loading failed, so
            # stale PDFs are not picked up again on the next submit.
            for staged_file in TMP_DIR.iterdir():
                staged_file.unlink()

        texts = split_documents(documents)
        if not texts:
            st.warning("No text could be extracted from the uploaded documents.")
            return

        print(f"Adding {len(texts)} texts to vector DB")
        # `texts` are document chunks, so they go through add_documents; the
        # retriever itself has no persist(), that lives on its vector store.
        retriever.add_documents(texts)
        retriever.vectorstore.persist()
    except Exception as e:
        st.error(f"An error occurred: {e}")
138
+
139
def boot():
    """Render the page: inputs, submit button, chat history and chat box."""
    input_fields()
    st.button("Submit Documents", on_click=process_documents)

    # First run of a session: start with an empty conversation.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay earlier turns; index 0 is the question, index 1 the answer.
    for turn in st.session_state.messages:
        st.chat_message('human').write(turn[0])
        st.chat_message('ai').write(turn[1])

    query = st.chat_input()
    if query:
        st.chat_message("human").write(query)
        response = query_llm(retriever, query)
        st.chat_message("ai").write(response)


if __name__ == '__main__':
    boot()
160
+