Ritesh-hf committed on
Commit
e635e64
1 Parent(s): 24e37bf

Update app.py

Files changed (1)
  1. app.py +102 -64
app.py CHANGED
@@ -26,6 +26,7 @@ import os
 load_dotenv(".env")
 USER_AGENT = os.getenv("USER_AGENT")
 GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
 SECRET_KEY = os.getenv("SECRET_KEY")
 SESSION_ID_DEFAULT = "abc123"
 
@@ -33,6 +34,7 @@ SESSION_ID_DEFAULT = "abc123"
 # Set environment variables
 os.environ['USER_AGENT'] = USER_AGENT
 os.environ["GROQ_API_KEY"] = GROQ_API_KEY
+os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
 os.environ["TOKENIZERS_PARALLELISM"] = 'true'
 
 # Initialize Flask app and SocketIO with CORS
@@ -43,94 +45,130 @@ app.config['SESSION_COOKIE_SECURE'] = True # Use HTTPS
 app.config['SESSION_COOKIE_HTTPONLY'] = True
 app.config['SECRET_KEY'] = SECRET_KEY
 
-embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-multilingual-base", model_kwargs={"trust_remote_code": True})
-llm = ChatGroq(
-    model="llama-3.1-70b-versatile",
-    temperature=0.0,
-    max_tokens=1024,
-    max_retries=2
-)
-
-excel_vectorstore = FAISS.load_local(folder_path="./faiss_excel_doc_index", embeddings=embed_model, allow_dangerous_deserialization=True)
-word_vectorstore = FAISS.load_local(folder_path="./faiss_recursive_split_word_doc_index", embeddings=embed_model, allow_dangerous_deserialization=True)
-excel_vectorstore.merge_from(word_vectorstore)
-combined_vectorstore = excel_vectorstore
-
-with open('combined_recursive_keyword_retriever.pkl', 'rb') as f:
-    combined_keyword_retriever = pickle.load(f)
-combined_keyword_retriever.k = 1000
-
-semantic_retriever = combined_vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 100})
-
-# initialize the ensemble retriever
-ensemble_retriever = EnsembleRetriever(
-    retrievers=[combined_keyword_retriever, semantic_retriever], weights=[0.5, 0.5]
-)
-
-embeddings_filter = EmbeddingsFilter(embeddings=embed_model, similarity_threshold=0.4)
-compression_retriever = ContextualCompressionRetriever(
-    base_compressor=embeddings_filter, base_retriever=semantic_retriever
-)
-
-template = """
-User Instructions:
-
-You are an Arabic AI Assistant focused on providing clear, concise responses.
-Always answer truthfully. If the user query is irrelevant to the provided CONTEXT, respond stating the reason.
-For general questions like greetings reply with formal greetings.
-Generate responses in Arabic. Format any English words and numbers appropriately for clarity.
-
-Round off numbers with decimal integers to two decimal integers.
-
-Use numbered lists where applicable for better organization.
-
-Provide detailed yet concise answers, covering all important aspects.
-Remember, responding outside the CONTEXT may lead to the termination of the interaction.
-CONTEXT: {context}
-Query: {question}
-
-After generating your response, ensure proper formatting and text direction of Arabic and English words/numbers. Return only the AI-generated answer.
-"""
-
-prompt = ChatPromptTemplate.from_template(template)
-output_parser = StrOutputParser()
-
-def format_docs(docs):
-    return "\n\n".join(doc.page_content for doc in docs)
-
-rag_chain = (
-    {"context": compression_retriever.with_config(run_name="Docs") | format_docs, "question": RunnablePassthrough()}
-    | prompt
-    | llm
-    | output_parser
+# Initialize Pinecone index and BM25 encoder
+pinecone_index = initialize_pinecone("uae-national-library-and-archives-vectorstore")
+bm25 = BM25Encoder().load("./UAE-NLA.json")
+
+old_embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+
+# Initialize models and retriever
+embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-multilingual-base", model_kwargs={"trust_remote_code": True})
+retriever = PineconeHybridSearchRetriever(
+    embeddings=embed_model,
+    sparse_encoder=bm25,
+    index=pinecone_index,
+    top_k=50,
+    alpha=0.5
+)
+
+# Initialize LLM
+llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, max_tokens=1024, max_retries=2)
+
+# Contextualization prompt and retriever
+contextualize_q_system_prompt = """Given a chat history and the latest user question \
+which might reference context in the chat history, formulate a standalone question \
+which can be understood without the chat history. Do NOT answer the question, \
+just reformulate it if needed and otherwise return it as is.
+"""
+contextualize_q_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", contextualize_q_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}")
+    ]
+)
+history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
+
+# QA system prompt and chain
+qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following context to answer questions effectively. \
+If you don't know the answer, simply state that you don't know. \
+Your answer should be in {language} language. \
+Provide answers in proper HTML format and keep them concise. \
+When responding to queries, follow these guidelines: \
+1. Provide Clear Answers: \
+   - Ensure the response directly addresses the query with accurate and relevant information. \
+2. Include Detailed References: \
+   - Links to Sources: Include URLs to credible sources where users can verify information or explore further. \
+   - Reference Sites: Mention specific websites or platforms that offer additional information. \
+   - Downloadable Materials: Provide links to any relevant downloadable resources if applicable. \
+
+3. Formatting for Readability: \
+   - The answer should be in proper HTML format with appropriate tags. \
+   - For Arabic-language responses, align the text to the right and convert numbers as well. \
+   - Double-check that the answer is in the correct language. \
+   - Use bullet points or numbered lists where applicable to present information clearly. \
+   - Highlight key details using bold or italics. \
+   - Provide proper and meaningful labels for URLs; do not include naked URLs. \
+
+4. Organize Content Logically: \
+   - Structure the content in a logical order, ensuring easy navigation and understanding for the user. \
+
+{context}
+"""
+qa_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", qa_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}")
+    ]
+)
+question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+# Retrieval-Augmented Generation (RAG) chain
+rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+# Chat message history storage
+store = {}
+
+def clean_temporary_data():
+    store.clear()
+
+def get_session_history(session_id: str) -> BaseChatMessageHistory:
+    if session_id not in store:
+        store[session_id] = ChatMessageHistory()
+    return store[session_id]
+
+# Conversational RAG chain with message history
+conversational_rag_chain = RunnableWithMessageHistory(
+    rag_chain,
+    get_session_history,
+    input_messages_key="input",
+    history_messages_key="chat_history",
+    output_messages_key="answer",
 )
 
 # Function to handle WebSocket connection
 @socketio.on('connect')
 def handle_connect():
-    emit('connection_response', {'message': 'Connected successfully.'}, room=request.sid)
-
-@socketio.on('ping')
-def handle_ping(data):
-    emit('ping_response', {'message': 'Healthy Connection.'}, room=request.sid)
+    print(f"Client connected: {request.sid}")
+    emit('connection_response', {'message': 'Connected successfully.'})
 
 # Function to handle WebSocket disconnection
 @socketio.on('disconnect')
 def handle_disconnect():
-    emit('connection_response', {'message': 'Disconnected successfully.'})
+    print(f"Client disconnected: {request.sid}")
+    clean_temporary_data()
 
 # Function to handle WebSocket messages
 @socketio.on('message')
 def handle_message(data):
     question = data.get('question')
+    language = data.get('language', '')
+    language = "English" if "en" in language else "Arabic"
+    session_id = data.get('session_id', SESSION_ID_DEFAULT)
     try:
-        for chunk in rag_chain.stream(question):
-            emit('response', chunk, room=request.sid)
+        for chunk in conversational_rag_chain.stream(
+            {"input": question, "language": language},
+            config={"configurable": {"session_id": session_id}},
+        ):
+            # Stream only answer tokens; passthrough chunks (input, context) are skipped.
+            if "answer" in chunk:
+                emit('response', chunk['answer'], room=request.sid)
     except Exception as e:
+        print(f"Error during message handling: {e}")
         emit('response', {"error": "An error occurred while processing your request."}, room=request.sid)
 
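The diff calls initialize_pinecone() but does not show its definition, which presumably lives elsewhere in app.py. A minimal hypothetical sketch of such a helper with the v3+ pinecone client, assuming the index already exists and PINECONE_API_KEY is exported as done at the top of the file:

    import os
    from pinecone import Pinecone

    def initialize_pinecone(index_name: str):
        # Hypothetical helper: connect to an existing Pinecone index by name.
        pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
        return pc.Index(index_name)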
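The sparse side of the hybrid retriever loads precomputed BM25 statistics from ./UAE-NLA.json. A one-off sketch of how such a file is typically produced with pinecone_text; the corpus below is a placeholder, not the actual library documents:

    from pinecone_text.sparse import BM25Encoder

    corpus = ["first document text", "second document text"]  # placeholder corpus
    bm25 = BM25Encoder()
    bm25.fit(corpus)             # learn term statistics from the corpus
    bm25.dump("./UAE-NLA.json")  # reload later with BM25Encoder().load(...)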
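After this commit, the message contract is: emit a 'message' event carrying question, language (a value containing "en" selects English, anything else Arabic), and an optional session_id; answer chunks stream back as 'response' events. A minimal python-socketio client sketch, assuming the app listens on http://localhost:7860 (the URL is an assumption, not taken from the diff):

    import socketio

    sio = socketio.Client()

    @sio.on('response')
    def on_response(chunk):
        # Each event carries one streamed answer chunk (or an error dict).
        print(chunk, end="", flush=True)

    sio.connect("http://localhost:7860")
    sio.emit('message', {"question": "What services does the archive offer?",
                         "language": "en", "session_id": "abc123"})
    sio.sleep(10)   # wait for streamed chunks
    sio.disconnect()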