Upload 3 files
- app.py +60 -26
- creds.yaml +23 -0
- herbal_expert.py +295 -0
app.py
CHANGED
@@ -1,36 +1,70 @@
 import streamlit as st
+import streamlit_authenticator as stauth
+from streamlit_authenticator import Authenticate
+from herbal_expert import herbal_expert
+import yaml

+def authentication_page():
+    with open('creds.yaml') as f:
+        creds = yaml.load(f, Loader=yaml.loader.SafeLoader)

+    print("creds: ", creds)

+    authenticator = Authenticate(
+        creds['credentials'],
+        creds['cookie']['name'],
+        creds['cookie']['key'],
+        creds['cookie']['expiry_days'],
+        creds['preauthorized']
+    )

+    name, authentication_status, username = authenticator.login('Login', 'main')
+    print("name: ", name)
+    print("authentication_status: ", authentication_status)
+    print("username: ", username)
+    if authentication_status:
+        authenticator.logout('Logout', 'main')
+        if st.session_state["authentication_status"]:
+            st.session_state.is_authenticated = True

+def chatbot_page():
+    st.title("Herbal Expert Chatbot")
+    # Store LLM generated responses
+    if "messages" not in st.session_state.keys():
+        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

+    # Display chat messages
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.write(message["content"])


+    # Function for generating LLM response
+    def generate_response(prompt_input):
+        print(st.session_state.messages)
+        response = herbal_expert.query_expert(prompt_input)
+        return response['response']

+
+    # User-provided prompt
+    if prompt := st.chat_input():
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.write(prompt)
+
+    # Generate a new response if last message is not from assistant
+    if st.session_state.messages[-1]["role"] != "assistant":
+        with st.chat_message("assistant"):
+            with st.spinner("Thinking..."):
+                response = generate_response(prompt)
+                st.write(response)
+        message = {"role": "assistant", "content": response}
+        st.session_state.messages.append(message)
+
+if __name__ == "__main__":
+    st.session_state.is_authenticated = False
+    authentication_page()
+
+    # Check if the user is authenticated before displaying the chatbot page
+    if st.session_state.is_authenticated:
+        chatbot_page()
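Note on the login flow above: Streamlit re-runs app.py from the top on every widget interaction, so the unconditional st.session_state.is_authenticated = False in the __main__ block resets the flag on each rerun; the app still behaves because Authenticate.login() restores st.session_state["authentication_status"] from its cookie on each run. A guarded initialization (a minimal sketch, not part of this commit) would avoid the reset entirely:

    import streamlit as st

    # Initialize the flag only when it is absent, so it survives reruns.
    if "is_authenticated" not in st.session_state:
        st.session_state.is_authenticated = False

    # After a successful login, flipping it once persists across interactions:
    # st.session_state.is_authenticated = True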
creds.yaml
ADDED
@@ -0,0 +1,23 @@
+credentials:
+  usernames:
+    demotester:
+      email: [email protected]
+      name: Demo Tester
+      password: $2b$12$EalJvnUtGoMl20600.XoQOBgjJElfMIenOhsZIi3jGx7EXg.s605a #test
+    me:
+      email: [email protected]
+      name: Anush
+      password: $2b$12$8skMHgwa5IBnbdg7gBjLmOTbUpJdyhUz7g7xSCHIhphcSmlFUPnKS #1234
+    executive:
+      email: [email protected]
+      name: Caleb
+      password: $2b$12$8skMHgwa5IBnbdg7gBjLmOTbUpJdyhUz7g7xSCHIhphcSmlFUPnKS
+
+cookie:
+  expiry_days: 30
+  key: "13141516" # Must be string
+  name: random_cookie_name
+preauthorized:
+  emails:
+
+
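The password fields above hold bcrypt hashes; the trailing comments record the demo plaintexts (test, 1234). A new hash for creds.yaml can be produced with the bcrypt package (a sketch under that assumption; streamlit-authenticator also ships a Hasher utility for the same purpose):

    import bcrypt

    plaintext = "newpassword"  # hypothetical password for a new user
    hashed = bcrypt.hashpw(plaintext.encode("utf-8"), bcrypt.gensalt(rounds=12))
    print(hashed.decode())  # "$2b$12$...", paste into the password field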
herbal_expert.py
ADDED
@@ -0,0 +1,295 @@
+import requests
+import json
+import random
+
+from langchain.agents import AgentExecutor, LLMSingleActionAgent, AgentOutputParser
+from langchain.prompts import StringPromptTemplate
+from langchain.schema import AgentAction, AgentFinish
+from langchain.memory import ConversationBufferWindowMemory
+from langchain import LLMChain
+from langchain.llms.base import LLM
+from Bio import Entrez
+from requests import HTTPError
+from nltk.stem import WordNetLemmatizer
+
+Entrez.email = "[email protected]"
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from typing import List, Union, Optional, Any
+
+
+class CustomLLM(LLM):
+    n: int
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        data = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": prompt
+                }
+            ],
+            "stop": ["### Instruction:"], "temperature": 0, "max_tokens": 512, "stream": False
+        }
+
+        response = requests.post("https://5423-2605-7b80-3d-320-916c-64e7-c70b-e72d.ngrok-free.app/v1/chat/completions",
+                                 headers={"Content-Type": "application/json"}, json=data)
+        return json.loads(response.text)['choices'][0]['message']['content']
+
+        # return make_inference_call(prompt)
+
+
+class CustomPromptTemplate(StringPromptTemplate):
+    template: str
+
+    def format(self, **kwargs) -> str:
+        return self.template.format(**kwargs)
+
+
+class CustomOutputParser(AgentOutputParser):
+    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
+        return AgentFinish(return_values={"output": llm_output}, log=llm_output)
+
+
+bare_output_parser = CustomOutputParser()
+question_decompose_prompt = """
+### Instruction: Given the previous conversation history and the current question, pick out the relevant keywords from the question that would be used to search a medical article database.
+Chat History: {history}
+Question: {input}
+
+Your response should be a list of keywords separated by commas:
+### Response:
+"""
+
+prompt_with_history = CustomPromptTemplate(
+    template=question_decompose_prompt,
+    tools=[],
+    input_variables=["input", "history"]
+)
+# %%
+llm = CustomLLM(n=10)
+question_decompose_chain = LLMChain(llm=llm, prompt=prompt_with_history)
+
+question_decompose_agent = LLMSingleActionAgent(
+    llm_chain=question_decompose_chain,
+    output_parser=bare_output_parser,
+    stop=["\nObservation:"],
+    allowed_tools=[]
+)
+
+memory = ConversationBufferWindowMemory(k=10)
+ax_1 = AgentExecutor.from_agent_and_tools(
+    agent=question_decompose_agent,
+    tools=[],
+    verbose=True,
+    memory=memory
+)
+
+
+def get_num_citations(pmid: str):
+    citations_xml = Entrez.read(
+        Entrez.elink(dbfrom="pubmed", db="pmc", LinkName="pubmed_pubmed_citedin", from_uid=pmid))
+
+    for i in range(0, len(citations_xml)):
+        if len(citations_xml[i]["LinkSetDb"]) > 0:
+            pmids_list = [link["Id"] for link in citations_xml[i]["LinkSetDb"][0]["Link"]]
+            return len(pmids_list)
+        else:
+            return 0
+
+
+def fetch_pubmed_articles(keywords, max_search=10, max_context=3):
+    """
+    The fetch_pubmed_articles function takes in a list of keywords and returns a list of articles.
+    The function uses the Entrez API to search for articles with the given keywords, then fetches
+    those articles from PubMed. The function returns a list of strings, where each string is an article.
+
+    :param keywords: Search terms for articles in the PubMed database
+    :param max_search: Maximum number of search results to consider (default 10)
+    :param max_context: Number of top-cited articles to keep (default 3)
+    :return: A list of strings
+    """
+
+    try:
+        search_result = Entrez.esearch(db="pubmed", term=keywords, retmax=max_search)
+        id_list = Entrez.read(search_result)["IdList"]
+
+        if len(id_list) == 0:
+            search_result = Entrez.esearch(db="pubmed", term=keywords[:4], retmax=max_search)
+            id_list = Entrez.read(search_result)["IdList"]
+
+        num_citations = [(id, get_num_citations(id)) for id in id_list]
+        top_n_papers = sorted(num_citations, key=lambda x: x[1], reverse=True)[:max_context]
+        print(f"top_{max_context}_papers: ", top_n_papers)
+
+        top_n_papers = [paper[0] for paper in top_n_papers]
+        fetch_handle = Entrez.efetch(db="pubmed", id=top_n_papers, rettype="medline", retmode="xml")
+        fetched_articles = Entrez.read(fetch_handle)
+
+        articles = []
+        # somehow only pull natural therapeutic articles
+        for fetched in fetched_articles['PubmedArticle']:
+            title = fetched['MedlineCitation']['Article']['ArticleTitle']
+            abstract = fetched['MedlineCitation']['Article']['Abstract']['AbstractText'][0] if 'Abstract' in fetched[
+                'MedlineCitation']['Article'] else "No Abstract"
+            # pmid = fetched['MedlineCitation']['PMID']
+            articles.append(title + "\n" + abstract)
+
+        return articles
+    except HTTPError as e:
+        print("HTTPError: ", e)
+        return []
+    except RuntimeError as e:
+        print("RuntimeError: ", e)
+        return []
+
+
+def call_model_with_history(messages: list):
+    """
+    The call_model_with_history function takes a list of messages and returns the next message in the conversation.
+
+    :param messages: list: Pass the history of messages to the model
+    :return: The text of the model's reply
+    """
+    data = {
+        "messages": messages,
+        "stop": ["### Instruction:"], "temperature": 0, "max_tokens": 512, "stream": False
+    }
+
+    response = requests.post("https://5423-2605-7b80-3d-320-916c-64e7-c70b-e72d.ngrok-free.app/v1/chat/completions", headers={"Content-Type": "application/json"}, json=data)
+    return json.loads(response.text)['choices'][0]['message']['content']
+
+
+
+# TODO: add ability to pass message history to model
+def format_prompt_and_query(prompt, **kwargs):
+    """
+    The format_prompt_and_query function takes a prompt and keyword arguments, formats the prompt with the keyword
+    arguments, and then calls call_model_with_history with a list of messages containing the formatted prompt.
+
+    :param prompt: Format the prompt with the values in kwargs
+    :param **kwargs: Pass a dictionary of key-value pairs to the formatting function
+    :return: The model's reply as a string
+    """
+
+    formatted_prompt = prompt.format(**kwargs)
+
+    messages = [
+        {"role": "system", "content": "Perform the instructions to the best of your ability."},
+        {"role": "user", "content": formatted_prompt}
+    ]
+
+    return call_model_with_history(messages)
+
+
+class HerbalExpert:
+    def __init__(self, qd_chain):
+        self.qd_chain = qd_chain
+        self.wnl = WordNetLemmatizer()
+        self.default_questions = [
+            "How is chamomile traditionally used in herbal medicine?",
+            "What are the potential side effects or interactions of consuming echinacea?",
+            "Can you explain the different methods of consuming lavender for health benefits?",
+            "Which herbs are commonly known for their anti-inflammatory properties?",
+            "I'm experiencing consistent stress and anxiety. What herbs or supplements could help alleviate these symptoms?",
+            "Are there any natural herbs that could support better sleep?",
+            "What cannabis or hemp products would you recommend for chronic pain relief?",
+            "I'm looking to boost my immune system. Are there any specific herbs or supplements that could help?",
+            "Which herbs or supplements are recommended for enhancing cognitive functions and memory?"
+        ]
+        # og = Original, qa = Question Asking, ri = Response Improvement
+        self.prompts = {
+            "og_answer_prompt": """### Instruction: Answer the following question using the given context. Question: {question}
+            Answer: ### Response: """,
+
+            "ans_decompose_prompt": """### Instruction: Given the following text, identify the 2 most important
+            keywords that capture the essence of the text. If there's a list of products, choose the top 2 products.
+            Your response should be a list of only 2 keywords separated by commas. Text: {original_answer} Keywords:
+            ### Response: """,
+
+            "qa_prompt": """### Instruction: Answer the following question using the given context.
+            Question: {question}
+            Context: {context}
+            ### Response: """,
+
+            "ri_prompt": """### Instruction: You are a caring, intelligent question answering agent. Craft a
+            response that is more informative and intelligent than the original answer and imparts knowledge from
+            both the old answer and from the context only if it helps answer the question.
+            Question: {question}
+            Old Answer: {answer}
+            Context: {answer2}
+            Improved answer: ### Response:"""
+        }
+
+    def process_query_words(self, question_words: str, answer_words: str):
+        # don't need to be searching for these in PubMed. Should we include: 'supplements', 'supplement'?
+        vague_words = ['recommendation', 'recommendations', 'products', 'product']
+        words = question_words.lower().split(",") + answer_words.lower().split(",")
+
+        final_list = []
+        for word in words:
+            cleaned = word.strip().strip('"')
+            if cleaned not in vague_words:
+                final_list.append(self.wnl.lemmatize(cleaned))
+
+        return list(set(final_list))
+
+    def convert_question_into_words(self, question: str):
+        original_answer = format_prompt_and_query(self.prompts["og_answer_prompt"], question=question)
+        print("Original Answer: ", original_answer)
+
+        question_decompose = self.qd_chain.run(question)
+        print("Question Decompose: ", question_decompose)
+
+        original_answer_decompose = format_prompt_and_query(self.prompts["ans_decompose_prompt"],
+                                                            original_answer=original_answer)
+        print("Original Answer Decomposed: ", original_answer_decompose)
+
+        words = self.process_query_words(question_decompose, original_answer_decompose)
+        return words, original_answer
+
+    def query_expert(self, question: str = None):
+        question = self.default_questions[
+            random.randint(0, len(self.default_questions) - 1)] if question is None else question
+        print("Question: ", question)
+
+        keywords, original_response = self.convert_question_into_words(question)
+        print("Keywords: ", keywords)
+
+        context = fetch_pubmed_articles(" AND ".join(keywords), max_search=5)
+
+        if len(context) == 0:
+            return {
+                "question": question,
+                "response": original_response,
+                "info": "No context found"
+            }
+
+        contextual_response = format_prompt_and_query(self.prompts["qa_prompt"], question=question, context=context)
+        improved_response = format_prompt_and_query(self.prompts["ri_prompt"], question=question,
+                                                    answer=original_response, answer2=contextual_response)
+
+        return {
+            "question": question,
+            "response": improved_response,
+            "info": "Success"
+        }
+
+
+herbal_expert = HerbalExpert(ax_1)
+
+
+if __name__ == '__main__':
+    herbal_expert = HerbalExpert(ax_1)
+    answer = herbal_expert.query_expert("I'm experiencing consistent stress and anxiety. What herbs or supplements could help alleviate these symptoms?")
+    print(answer['response'])
+    # return to api? who knows
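For reference, the retrieval half of query_expert can be exercised on its own: Entrez.esearch returns PMIDs, which get_num_citations ranks by citation count before Entrez.efetch pulls the abstracts. A standalone sketch of the search step using the same Biopython Entrez module (the email address below is a placeholder; NCBI requires a real one):

    from Bio import Entrez

    Entrez.email = "[email protected]"  # placeholder; Entrez requires an email
    handle = Entrez.esearch(db="pubmed", term="chamomile AND anxiety", retmax=5)
    pmids = Entrez.read(handle)["IdList"]  # list of PMID strings
    print(pmids)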