NaikPriyank committed on
Commit 3f1425d · verified · 1 parent: 3f33752

Upload 2 files

Files changed (2)
  1. genAI.py +273 -0
  2. requirements.txt +9 -0
genAI.py ADDED
@@ -0,0 +1,273 @@
+ import streamlit as st
+ from streamlit_chat import message
+ import json
+ import os
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+ import faiss
+ import google.generativeai as genai
+ from flashrank.Ranker import Ranker, RerankRequest
+ from langchain.memory import ConversationBufferMemory
+ from pydantic import ConfigDict
+
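+ # Configure the Google Generative AI key (set GOOGLE_API_KEY in the
+ # environment rather than hardcoding the secret in source)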
+ genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+
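+ # ConversationBufferMemory subclass configured to allow arbitrary field
+ # types under Pydantic v2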
+ class CustomMemory(ConversationBufferMemory):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
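+ # Load the uploaded JSON transcript and build speaker-tagged passages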
+ def load_and_preprocess(uploaded_file):
+     data = json.load(uploaded_file)
+     passages = [f"Speaker: {item['speaker']}. Text: {item['text']}"
+                 for item in data if item["text"].strip()]
+     return data, passages
+
+
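+ # Load the BGE-M3 embedding model and its tokenizer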
+ def load_model(model_name="BAAI/bge-m3"):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModel.from_pretrained(model_name)
+     return tokenizer, model
+
+
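+ # Embed passages in batches, mean-pooling the last hidden state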
+ def generate_embeddings(passages, tokenizer, model, batch_size=10,
+                         device="cuda" if torch.cuda.is_available() else "cpu"):
+     model.to(device)
+     embeddings = []
+     for i in range(0, len(passages), batch_size):
+         batch = passages[i:i + batch_size]
+         inputs = tokenizer(batch, return_tensors="pt", padding=True,
+                            truncation=True, max_length=512).to(device)
+         with torch.no_grad():
+             outputs = model(**inputs).last_hidden_state.mean(dim=1)
+         embeddings.append(outputs.cpu())
+     embeddings = torch.cat(embeddings, dim=0)
+     return embeddings.numpy()
+
+
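+ # Store embeddings in a flat (exact-search) L2 FAISS index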
+ def store_in_faiss(embeddings):
+     dimension = embeddings.shape[1]
+     index = faiss.IndexFlatL2(dimension)
+     index.add(embeddings)
+     return index
+
+
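+ # Embed the query and retrieve the k nearest passages from FAISS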
+ def retrieve_top_k(query, tokenizer, model, faiss_index, passages, k=5,
+                    device="cuda" if torch.cuda.is_available() else "cpu"):
+     model.to(device)
+     inputs = tokenizer([query], return_tensors="pt", padding=True,
+                        truncation=True, max_length=512).to(device)
+     with torch.no_grad():
+         query_embedding = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
+     distances, indices = faiss_index.search(query_embedding, k)
+     retrieved_passages = [passages[i] for i in indices[0]]
+     return retrieved_passages
+
+
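+ # Rerank the retrieved passages with FlashRank's rank-T5-flan model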
+ def rerank_passages(query, passages):
+     formatted_passages = [{"text": passage} for passage in passages]
+     ranker = Ranker(model_name="rank-T5-flan", cache_dir="/my_cache_dir")  # Adjust cache directory as needed
+     rerank_request = RerankRequest(query=query, passages=formatted_passages)
+     results = ranker.rerank(rerank_request)
+     return results
+
+
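+ # Answer the question with Gemini 1.5 Flash, grounded in the reranked context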
+ def generate_response(context, query):
+     input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
+     model = genai.GenerativeModel("gemini-1.5-flash")
+     response = model.generate_content(input_text)
+     return response.text
+
+
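+ # One chat turn: retrieve, rerank, generate, then record the exchange in memory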
+ def handle_userinput(user_question):
+     top_k_passages = retrieve_top_k(user_question, st.session_state.tokenizer,
+                                     st.session_state.model, st.session_state.faiss_index,
+                                     st.session_state.passages)
+     reranked_passages = rerank_passages(user_question, top_k_passages)
+
+     context = " ".join([passage["text"] for passage in reranked_passages])
+     response = generate_response(context, user_question)
+
+     st.session_state.memory.chat_memory.add_user_message(user_question)
+     st.session_state.memory.chat_memory.add_ai_message(response)
+     return response
+
+
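+ # Streamlit app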
+ def main():
+     st.set_page_config(page_title="Chatbot with MoM Document Upload", layout="wide")
+     st.title("📄 Chatbot for Minutes of Meeting")
+
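+     # Initialize session state so the model, index, and memory survive reruns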
+     if "memory" not in st.session_state:
+         st.session_state.memory = CustomMemory(memory_key='chat_history', return_messages=True)
+     if "faiss_index" not in st.session_state:
+         st.session_state.faiss_index = None
+     if "passages" not in st.session_state:
+         st.session_state.passages = None
+     if "tokenizer" not in st.session_state or "model" not in st.session_state:
+         st.session_state.tokenizer, st.session_state.model = load_model()
+
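+     # File upload: embed the transcript and build the FAISS index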
+     uploaded_file = st.file_uploader("Upload a JSON file for processing", type=["json"])
+     if uploaded_file:
+         st.write("Processing the file...")
+         data, passages = load_and_preprocess(uploaded_file)
+         st.session_state.passages = passages
+
+         tokenizer, model = st.session_state.tokenizer, st.session_state.model
+         embeddings = generate_embeddings(passages, tokenizer, model)
+         st.session_state.faiss_index = store_in_faiss(embeddings)
+         st.success("File processed and embeddings generated successfully!")
+
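+     # Chat interface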
+     if st.session_state.faiss_index:
+         st.header("Ask a Question")
+         user_query = st.text_input("Type your question here:")
+         if user_query:
+             response = handle_userinput(user_query)
+
+             if "chat_history_ui" not in st.session_state:
+                 st.session_state.chat_history_ui = []
+             st.session_state.chat_history_ui.append({"role": "user", "content": user_query})
+             st.session_state.chat_history_ui.append({"role": "bot", "content": response})
+
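+     # Render the conversation as streamlit_chat message bubbles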
+     if "chat_history_ui" in st.session_state:
+         for i, chat in enumerate(st.session_state.chat_history_ui):
+             if chat["role"] == "user":
+                 message(chat["content"], is_user=True, key=f"user_{i}")
+             else:
+                 message(chat["content"], is_user=False, key=f"bot_{i}")
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ streamlit==1.21.0
+ streamlit-chat==0.1.1
+ torch==2.5.1
+ transformers==4.48.0
+ faiss-cpu==1.9.0
+ google-generativeai==0.8.3
+ flashrank==0.2.10
+ langchain==0.2.17
+ pydantic==2.10.5