anushm committed on
Commit
2658964
1 Parent(s): 44ce708

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +60 -26
  2. creds.yaml +23 -0
  3. herbal_expert.py +295 -0
app.py CHANGED
@@ -1,36 +1,70 @@
1
  import streamlit as st
 
 
 
 
2
 
3
- # App title
4
- st.set_page_config(page_title="Herbal-Expert")
 
5
 
 
6
 
7
- # Store LLM generated responses
8
- if "messages" not in st.session_state.keys():
9
- st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
 
 
 
 
10
 
11
- # Display chat messages
12
- for message in st.session_state.messages:
13
- with st.chat_message(message["role"]):
14
- st.write(message["content"])
 
 
 
 
15
 
 
 
 
 
 
16
 
17
- # Function for generating LLM response
18
- def generate_response(prompt_input):
19
- print("someone's here")
20
- return "hello world"
21
 
22
 
23
- # User-provided prompt
24
- if prompt := st.chat_input():
25
- st.session_state.messages.append({"role": "user", "content": prompt})
26
- with st.chat_message("user"):
27
- st.write(prompt)
28
 
29
- # Generate a new response if last message is not from assistant
30
- if st.session_state.messages[-1]["role"] != "assistant":
31
- with st.chat_message("assistant"):
32
- with st.spinner("Thinking..."):
33
- response = generate_response(prompt)
34
- st.write(response)
35
- message = {"role": "assistant", "content": response}
36
- st.session_state.messages.append(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import streamlit_authenticator as stauth
3
+ from streamlit_authenticator import Authenticate
4
+ from herbal_expert import herbal_expert
5
+ import yaml
6
 
7
def authentication_page():
    """Render the login widget and mark the session authenticated on success.

    Loads credentials from creds.yaml, drives streamlit_authenticator's
    login flow, and sets ``st.session_state.is_authenticated = True`` once
    the user has logged in. Shows an error message on bad credentials.
    """
    # SafeLoader prevents arbitrary object construction from the YAML file.
    with open('creds.yaml') as f:
        creds = yaml.load(f, Loader=yaml.loader.SafeLoader)

    authenticator = Authenticate(
        creds['credentials'],
        creds['cookie']['name'],
        creds['cookie']['key'],
        creds['cookie']['expiry_days'],
        creds['preauthorized']
    )

    name, authentication_status, username = authenticator.login('Login', 'main')

    if authentication_status:
        # Logged in: offer a logout button and unlock the chatbot page.
        # (Previous version also re-checked st.session_state["authentication_status"]
        # here, which authenticator.login has already populated — redundant.)
        authenticator.logout('Logout', 'main')
        st.session_state.is_authenticated = True
    elif authentication_status is False:
        # Credentials were submitted but did not match.
        st.error('Username or password is incorrect')
    # authentication_status is None: form not yet submitted — render nothing.
29
 
30
def chatbot_page():
    """Render the chat UI: title, stored history, input box, and replies."""
    st.title("Herbal Expert Chatbot")

    # Seed the conversation with a greeting on first load.
    # (Idiom fix: membership test on session_state directly, not .keys().)
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

    # Replay the stored conversation so it survives Streamlit reruns.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # Function for generating LLM response
    def generate_response(prompt_input):
        """Ask the herbal-expert pipeline for an answer to prompt_input."""
        # Removed: debug print of the full message history on every call.
        response = herbal_expert.query_expert(prompt_input)
        return response['response']

    # User-provided prompt
    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

    # Generate a new response if last message is not from assistant.
    # NOTE(review): `prompt` is only bound when chat_input returned a value
    # this run; that holds whenever the last message is the user's — confirm.
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_response(prompt)
            st.write(response)
        message = {"role": "assistant", "content": response}
        st.session_state.messages.append(message)
63
+
64
if __name__ == "__main__":
    # Streamlit re-executes the whole script on every interaction, so the
    # flag is reset each run and re-derived by authentication_page() (the
    # auth cookie lets returning users skip retyping credentials).
    st.session_state.is_authenticated = False
    authentication_page()

    # Check if the user is authenticated before displaying the chatbot page
    if st.session_state.is_authenticated:
        chatbot_page()
creds.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Credentials consumed by streamlit_authenticator.Authenticate in app.py.
# NOTE(review): this file holds bcrypt password hashes and the cookie
# signing key; it should not be committed to version control. The previous
# revision also leaked the PLAINTEXT passwords in inline comments next to
# the hashes — those comments have been removed; rotate these passwords.
credentials:
  usernames:
    demotester:

      name: Demo Tester
      password: $2b$12$EalJvnUtGoMl20600.XoQOBgjJElfMIenOhsZIi3jGx7EXg.s605a
    me:

      name: Anush
      password: $2b$12$8skMHgwa5IBnbdg7gBjLmOTbUpJdyhUz7g7xSCHIhphcSmlFUPnKS
    executive:

      name: Caleb
      password: $2b$12$8skMHgwa5IBnbdg7gBjLmOTbUpJdyhUz7g7xSCHIhphcSmlFUPnKS

cookie:
  expiry_days: 30
  key: "13141516" # Must be string
  name: random_cookie_name
preauthorized:
  emails:
22
23
herbal_expert.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import random
4
+
5
+ from langchain.agents import AgentExecutor, LLMSingleActionAgent, AgentOutputParser
6
+ from langchain.prompts import StringPromptTemplate
7
+ from langchain.schema import AgentAction, AgentFinish
8
+ from langchain.memory import ConversationBufferWindowMemory
9
+ from langchain import LLMChain
10
+ from langchain.llms.base import LLM
11
+ from Bio import Entrez
12
+ from requests import HTTPError
13
+ from nltk.stem import WordNetLemmatizer
14
+
15
+ Entrez.email = "[email protected]"
16
+
17
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
18
+ from typing import List, Union, Optional, Any
19
+
20
+
21
class CustomLLM(LLM):
    """LangChain LLM wrapper that proxies single-turn completions to a
    remote OpenAI-compatible chat endpoint (served via an ngrok tunnel)."""

    # Carried over from the LangChain custom-LLM example; not read by _call.
    n: int

    @property
    def _llm_type(self) -> str:
        # Identifier reported to LangChain for logging/serialization.
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Send *prompt* as one user message; return the assistant's reply.
        # NOTE(review): stop/run_manager/kwargs are accepted but ignored —
        # the stop sequence is hard-coded below. The ngrok URL is ephemeral
        # and breaks when the tunnel restarts; consider making it config.
        data = {
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "stop": ["### Instruction:"], "temperature": 0, "max_tokens": 512, "stream": False
        }

        response = requests.post("https://5423-2605-7b80-3d-320-916c-64e7-c70b-e72d.ngrok-free.app/v1/chat/completions",
                                 headers={"Content-Type": "application/json"}, json=data)
        return json.loads(response.text)['choices'][0]['message']['content']

        # return make_inference_call(prompt)
50
+
51
+
52
+ class CustomPromptTemplate(StringPromptTemplate):
53
+ template: str
54
+
55
+ def format(self, **kwargs) -> str:
56
+ return self.template.format(**kwargs)
57
+
58
+
59
+ class CustomOutputParser(AgentOutputParser):
60
+ def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
61
+ return AgentFinish(return_values={"output": llm_output}, log=llm_output)
62
+
63
+
64
+ bare_output_parser = CustomOutputParser()
65
+ question_decompose_prompt = """
66
+ ### Instruction: Given the previous conversation history and the current question, pick out the relevant keywords from the question that would be used to search a medical article database.
67
+ Chat History: {history}
68
+ Question: {input}
69
+
70
+ Your response should be a list of keywords separated by commas:
71
+ ### Response:
72
+ """
73
+
74
+ prompt_with_history = CustomPromptTemplate(
75
+ template=question_decompose_prompt,
76
+ tools=[],
77
+ input_variables=["input", "history"]
78
+ )
79
+ # %%
80
+ llm = CustomLLM(n=10)
81
+ question_decompose_chain = LLMChain(llm=llm, prompt=prompt_with_history)
82
+
83
+ question_decompose_agent = LLMSingleActionAgent(
84
+ llm_chain=question_decompose_chain,
85
+ output_parser=bare_output_parser,
86
+ stop=["\nObservation:"],
87
+ allowed_tools=[]
88
+ )
89
+
90
+ memory = ConversationBufferWindowMemory(k=10)
91
+ ax_1 = AgentExecutor.from_agent_and_tools(
92
+ agent=question_decompose_agent,
93
+ tools=[],
94
+ verbose=True,
95
+ memory=memory
96
+ )
97
+
98
+
99
+ def get_num_citations(pmid: str):
100
+ citations_xml = Entrez.read(
101
+ Entrez.elink(dbfrom="pubmed", db="pmc", LinkName="pubmed_pubmed_citedin", from_uid=pmid))
102
+
103
+ for i in range(0, len(citations_xml)):
104
+ if len(citations_xml[i]["LinkSetDb"]) > 0:
105
+ pmids_list = [link["Id"] for link in citations_xml[i]["LinkSetDb"][0]["Link"]]
106
+ return len(pmids_list)
107
+ else:
108
+ return 0
109
+
110
def fetch_pubmed_articles(keywords, max_search=10, max_context=3):
    """
    Search PubMed for *keywords* and return the most-cited matching articles.

    Searches for up to max_search article ids, ranks them by citation count
    (via get_num_citations), then fetches the top max_context articles and
    returns each as a "title\\nabstract" string.

    :param keywords: Query string for the PubMed database (e.g. "a AND b")
    :param max_search: Maximum number of search hits to consider
    :param max_context: Number of top-cited articles to fetch and return
    :return: A list of strings; empty list on HTTP or Entrez parse errors
    """

    try:
        search_result = Entrez.esearch(db="pubmed", term=keywords, retmax=max_search)
        id_list = Entrez.read(search_result)["IdList"]

        # Fallback retry on zero hits. NOTE(review): keywords[:4] slices the
        # first 4 CHARACTERS of the query string, not the first 4 keywords —
        # likely unintended; confirm before relying on this branch.
        if len(id_list) == 0:
            search_result = Entrez.esearch(db="pubmed", term=keywords[:4], retmax=max_search)
            id_list = Entrez.read(search_result)["IdList"]

        # Rank candidate ids by citation count; keep the max_context best.
        num_citations = [(id, get_num_citations(id)) for id in id_list]
        top_n_papers = sorted(num_citations, key=lambda x: x[1], reverse=True)[:max_context]
        print(f"top_{max_context}_papers: ", top_n_papers)

        top_n_papers = [paper[0] for paper in top_n_papers]
        fetch_handle = Entrez.efetch(db="pubmed", id=top_n_papers, rettype="medline", retmode="xml")
        fetched_articles = Entrez.read(fetch_handle)

        articles = []
        # somehow only pull natural therapeutic articles
        for fetched in fetched_articles['PubmedArticle']:
            title = fetched['MedlineCitation']['Article']['ArticleTitle']
            # Articles without an Abstract section fall back to a placeholder.
            abstract = fetched['MedlineCitation']['Article']['Abstract']['AbstractText'][0] if 'Abstract' in fetched[
                'MedlineCitation']['Article'] else "No Abstract"
            # pmid = fetched['MedlineCitation']['PMID']
            articles.append(title + "\n" + abstract)

        return articles
    except HTTPError as e:
        # Network/Entrez failure: degrade to "no context" rather than crash.
        print("HTTPError: ", e)
        return []
    except RuntimeError as e:
        # Entrez.read raises RuntimeError on malformed responses.
        print("RuntimeError: ", e)
        return []
154
+
155
+
156
def call_model_with_history(messages: list):
    """
    Send a full message history to the remote chat endpoint and return the
    assistant's next reply.

    :param messages: list: role/content dicts forming the conversation so far
    :return: the text content of the model's reply
    """
    payload = {
        "messages": messages,
        "stop": ["### Instruction:"], "temperature": 0, "max_tokens": 512, "stream": False
    }

    endpoint = "https://5423-2605-7b80-3d-320-916c-64e7-c70b-e72d.ngrok-free.app/v1/chat/completions"
    reply = requests.post(endpoint, headers={"Content-Type": "application/json"}, json=payload)
    parsed = json.loads(reply.text)
    return parsed['choices'][0]['message']['content']
170
+
171
+
172
+
173
# TODO: add ability to pass message history to model
def format_prompt_and_query(prompt, **kwargs):
    """
    Interpolate *kwargs* into *prompt* and send the result to the model as
    a single user turn, preceded by a generic system instruction.

    :param prompt: Template string with {placeholders} to fill from kwargs
    :param **kwargs: Values substituted into the prompt template
    :return: The model's reply text
    """

    system_turn = {"role": "system", "content": "Perform the instructions to the best of your ability."}
    user_turn = {"role": "user", "content": prompt.format(**kwargs)}

    return call_model_with_history([system_turn, user_turn])
192
+
193
+
194
class HerbalExpert:
    """Retrieval-augmented QA pipeline for herbal-medicine questions.

    Per query: draft a context-free answer, extract PubMed search keywords
    from both the question and the draft, fetch the most-cited matching
    abstracts, answer again with that context, then merge both answers
    into one improved response.
    """

    def __init__(self, qd_chain):
        # Agent executor that extracts search keywords from a question.
        self.qd_chain = qd_chain
        # Lemmatizer used to normalize keywords before deduplication.
        self.wnl = WordNetLemmatizer()
        # Picked from at random when query_expert is called with no question.
        self.default_questions = [
            "How is chamomile traditionally used in herbal medicine?",
            "What are the potential side effects or interactions of consuming echinacea?",
            "Can you explain the different methods of consuming lavender for health benefits?",
            "Which herbs are commonly known for their anti-inflammatory properties?",
            "I'm experiencing consistent stress and anxiety. What herbs or supplements could help alleviate these symptoms?",
            "Are there any natural herbs that could support better sleep?",
            "What cannabis or hemp products would you recommend for chronic pain relief?",
            "I'm looking to boost my immune system. Are there any specific herbs or supplements that could help?",
            "Which herbs or supplements are recommended for enhancing cognitive functions and memory?"
        ]
        # og = Original, qa = Question Asking, ri = Response Improvement
        self.prompts = {
            "og_answer_prompt": """### Instruction: Answer the following question using the given context. Question: {question}
            Answer: ### Response: """,

            "ans_decompose_prompt": """### Instruction: Given the following text, identify the 2 most important
            keywords that capture the essence of the text. If there's a list of products, choose the top 2 products.
            Your response should be a list of only 2 keywords separated by commas. Text: {original_answer} Keywords:
            ### Response: """,

            "qa_prompt": """### Instruction: Answer the following question using the given context.
            Question: {question}
            Context: {context}
            ### Response: """,

            "ri_prompt": """### Instruction: You are an caring, intelligent question answering agent. Craft a
            response that is more informative and intelligent than the original answer and imparts knowledge from
            both the old answer and from the context only if it helps answer the question.
            Question: {question}
            Old Answer: {answer}
            Context: {answer2}
            Improved answer: ### Response:"""
        }

    def process_query_words(self, question_words: str, answer_words: str):
        """Merge two comma-separated keyword strings into one deduplicated,
        lemmatized list, dropping generic words with no search value."""
        # don't need to be searching for these in pubmed. Should we include: 'supplements', 'supplement'
        vague_words = ['recommendation', 'recommendations', 'products', 'product']
        words = question_words.lower().split(",") + answer_words.lower().split(",")

        final_list = []
        for word in words:
            # Strip whitespace and stray quotes the model may emit.
            cleaned = word.strip().strip('"')
            if cleaned not in vague_words:
                final_list.append(self.wnl.lemmatize(cleaned))

        # set() dedupes; NOTE: output order is therefore unspecified.
        return list(set(final_list))

    def convert_question_into_words(self, question: str):
        """Return (keywords, draft_answer) for *question*.

        The draft is answered without external context; keywords are then
        extracted from both the question and the draft answer."""
        original_answer = format_prompt_and_query(self.prompts["og_answer_prompt"], question=question)
        print("Original Answer: ", original_answer)

        question_decompose = self.qd_chain.run(question)
        print("Question Decompose: ", question_decompose)

        original_answer_decompose = format_prompt_and_query(self.prompts["ans_decompose_prompt"],
                                                            original_answer=original_answer)
        print("Original Answer Decomposed: ", original_answer_decompose)

        words = self.process_query_words(question_decompose, original_answer_decompose)
        return words, original_answer

    def query_expert(self, question: str = None):
        """Answer *question* (or a random default when None) and return a
        dict with keys "question", "response", and "info".

        Falls back to the context-free draft answer when no PubMed context
        could be retrieved."""
        question = self.default_questions[
            random.randint(0, len(self.default_questions) - 1)] if question is None else question
        print("Question: ", question)

        keywords, original_response = self.convert_question_into_words(question)
        print("Keywords: ", keywords)

        # AND the keywords so retrieved abstracts mention all of them.
        context = fetch_pubmed_articles(" AND ".join(keywords), max_search=5)

        if len(context) == 0:
            return {
                "question": question,
                "response": original_response,
                "info": "No context found"
            }

        contextual_response = format_prompt_and_query(self.prompts["qa_prompt"], question=question, context=context)
        improved_response = format_prompt_and_query(self.prompts["ri_prompt"], question=question,
                                                    answer=original_response, answer2=contextual_response)

        return {
            "question": question,
            "response": improved_response,
            "info": "Success"
        }
286
+
287
+
288
# Module-level singleton imported by app.py.
herbal_expert = HerbalExpert(ax_1)


if __name__ == '__main__':
    # NOTE(review): re-instantiating here shadows the module-level singleton
    # created above — redundant but harmless for this smoke test.
    herbal_expert = HerbalExpert(ax_1)
    answer = herbal_expert.query_expert("I'm experiencing consistent stress and anxiety. What herbs or supplements could help alleviate these symptoms?")
    print(answer['response'])
    # return to api? who knows