# Hugging Face Spaces page header (captured with this source): Sleeping
import json
import os
import random
import urllib.error
from typing import List, Union, Optional, Any

import requests
from Bio import Entrez
from langchain import LLMChain
from langchain.agents import AgentExecutor, LLMSingleActionAgent, AgentOutputParser
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import StringPromptTemplate
from langchain.schema import AgentAction, AgentFinish
from nltk.stem import WordNetLemmatizer
from requests import HTTPError

Entrez.email = "[email protected]"
# Base URL of the OpenAI-compatible completion server. Callers append
# "v1/chat/completions" directly, so the value must end with exactly one "/".
# Free ngrok tunnel URLs are typically ephemeral, so the NGROK_URL environment
# variable overrides the hard-coded default without editing the code.
ngrok_url = os.environ.get(
    "NGROK_URL",
    "https://2590-2605-7b80-3d-320-a515-4f0d-f60e-71e5.ngrok-free.app/",
).rstrip("/") + "/"
class CustomLLM(LLM):
    """LangChain LLM backed by the OpenAI-compatible chat-completions server at ``ngrok_url``.

    Each prompt is sent as a single user message and the text of the first
    returned choice is used as the completion.
    """

    # Required constructor field (the module builds ``CustomLLM(n=10)``); not read by this class.
    n: int

    @property
    def _llm_type(self) -> str:
        # LangChain declares ``_llm_type`` as a property and reads it as an attribute
        # (e.g. ``LLM.dict()`` stores it under "_type"), so it must not be a plain method.
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send *prompt* to the completion server and return the generated text.

        :param prompt: Fully formatted prompt text.
        :param stop: Extra stop sequences requested by the caller (e.g. an agent);
            they are sent in addition to the default ``"### Instruction:"``.
        :param run_manager: LangChain callback manager (unused).
        :return: Content of the first choice's message.
        :raises requests.HTTPError: If the server answers with an error status.
        """
        # Honour caller-supplied stop sequences instead of silently dropping them.
        stop_sequences = ["### Instruction:"]
        for token in stop or []:
            if token not in stop_sequences:
                stop_sequences.append(token)
        data = {
            "messages": [{"role": "user", "content": prompt}],
            "stop": stop_sequences,
            "temperature": 0,
            "max_tokens": 512,
            "stream": False,
        }
        response = requests.post(
            ngrok_url + "v1/chat/completions",
            headers={"Content-Type": "application/json"},
            json=data,
            timeout=300,  # never hang forever if the tunnel/server is down
        )
        # Surface HTTP failures clearly rather than as a KeyError on 'choices'.
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
class CustomPromptTemplate(StringPromptTemplate):
    """String prompt template whose text is filled in with ``str.format``."""

    # Template text containing ``{name}`` fields for the declared input variables.
    template: str

    def format(self, **kwargs) -> str:
        """Return ``self.template`` formatted with the given keyword values."""
        return str.format(self.template, **kwargs)
class CustomOutputParser(AgentOutputParser):
    """Agent output parser that treats every completion as the final answer."""

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        """Wrap *llm_output* in an ``AgentFinish``; no tool action is ever produced."""
        final_values = {"output": llm_output}
        return AgentFinish(log=llm_output, return_values=final_values)
# The keyword-extraction agent never calls tools: the raw completion is its answer.
bare_output_parser = CustomOutputParser()

# Prompt that turns the current question (plus chat history) into PubMed search keywords.
question_decompose_prompt = """
### Instruction: Given the previous conversation history and the current question, pick out the relevant keywords from the question that would be used to search a medical article database.
Chat History: {history}
Question: {input}
Your response should be a list of keywords separated by commas:
### Response:
"""

# ``tools`` is intentionally not passed: CustomPromptTemplate declares no such
# field and its template never uses it.
prompt_with_history = CustomPromptTemplate(
    template=question_decompose_prompt,
    input_variables=["input", "history"],
)

# %%
llm = CustomLLM(n=10)
question_decompose_chain = LLMChain(llm=llm, prompt=prompt_with_history)
question_decompose_agent = LLMSingleActionAgent(
    llm_chain=question_decompose_chain,
    output_parser=bare_output_parser,
    stop=["\nObservation:"],
    allowed_tools=[],
)
# Window memory (last 10 exchanges) supplies the {history} variable of the prompt.
memory = ConversationBufferWindowMemory(k=10)
ax_1 = AgentExecutor.from_agent_and_tools(
    agent=question_decompose_agent,
    tools=[],
    verbose=True,
    memory=memory,
)
def get_num_citations(pmid: str) -> int:
    """
    Return how many citing-article links NCBI reports for the PubMed record *pmid*.

    Queries Entrez ``elink`` with ``LinkName="pubmed_pubmed_citedin"`` and counts
    the links in the first LinkSetDb of the first returned LinkSet.

    :param pmid: PubMed ID of the article.
    :return: Number of links, or 0 when NCBI returns none (never ``None``, so
        results can always be sorted numerically).
    """
    handle = Entrez.elink(dbfrom="pubmed", db="pmc", LinkName="pubmed_pubmed_citedin", from_uid=pmid)
    try:
        link_sets = Entrez.read(handle)
    finally:
        handle.close()  # release the connection even if parsing fails
    # A single ID is queried, so only the first LinkSet is relevant.
    if not link_sets:
        return 0
    link_dbs = link_sets[0].get("LinkSetDb", [])
    if not link_dbs:
        return 0
    return len(link_dbs[0]["Link"])
def _entrez_read(handle):
    """Parse an Entrez response handle with ``Entrez.read`` and always close it."""
    try:
        return Entrez.read(handle)
    finally:
        handle.close()


def fetch_pubmed_articles(keywords, max_search=10, max_context=3):
    """
    Search PubMed for *keywords* and return the most-cited matching articles as text.

    The query is run with Entrez ``esearch``. If it finds nothing and consists of
    more than four ``" AND "``-joined terms, it is retried with only the first four
    terms. The hits are ranked with :func:`get_num_citations` and the top
    ``max_context`` are fetched with ``efetch``.

    :param keywords: PubMed query string (terms joined by ``" AND "``), or an
        iterable of terms, which is joined with ``" AND "``.
    :param max_search: Maximum number of PubMed IDs requested from the search.
    :param max_context: Maximum number of articles returned.
    :return: One string per fetched article: the title and the abstract joined by
        a newline (``"No Abstract"`` when there is none). ``[]`` when nothing is
        found or an Entrez request fails.
    """
    # Work on individual terms so the fallback can drop whole terms, not characters.
    if isinstance(keywords, str):
        terms = keywords.split(" AND ")
    else:
        terms = [str(term) for term in keywords]
    query = " AND ".join(terms)  # identical to *keywords* when a string was given

    try:
        id_list = _entrez_read(Entrez.esearch(db="pubmed", term=query, retmax=max_search))["IdList"]
        if not id_list and len(terms) > 4:
            # Too many AND-ed terms over-constrain the search; retry with the first four.
            fallback_query = " AND ".join(terms[:4])
            id_list = _entrez_read(Entrez.esearch(db="pubmed", term=fallback_query, retmax=max_search))["IdList"]

        citation_counts = [(pmid, get_num_citations(pmid)) for pmid in id_list]
        top_n_papers = sorted(citation_counts, key=lambda pair: pair[1], reverse=True)[:max_context]
        print(f"top_{max_context}_papers: ", top_n_papers)
        top_ids = [pmid for pmid, _ in top_n_papers]
        if not top_ids:
            # Nothing to fetch; efetch with an empty id list is not a valid request.
            return []
        fetched_articles = _entrez_read(
            Entrez.efetch(db="pubmed", id=top_ids, rettype="medline", retmode="xml")
        )
    except (HTTPError, urllib.error.URLError) as e:
        # Entrez talks to NCBI through urllib, so its network failures are
        # urllib.error exceptions rather than requests.HTTPError.
        print(f"{type(e).__name__}: ", e)
        return []
    except RuntimeError as e:
        # Entrez.read raises RuntimeError for error messages returned by NCBI.
        print("RuntimeError: ", e)
        return []

    articles = []
    # somehow only pull natural therapeutic articles
    for record in fetched_articles.get("PubmedArticle", []):
        article = record["MedlineCitation"]["Article"]
        title = article["ArticleTitle"]
        # Structured abstracts are split into several AbstractText sections; keep them all.
        sections = article["Abstract"]["AbstractText"] if "Abstract" in article else []
        abstract = " ".join(sections) or "No Abstract"
        articles.append(title + "\n" + abstract)
    return articles
def call_model_with_history(messages: list, timeout: float = 300):
    """
    Send a chat conversation to the completion server and return the model's reply.

    :param messages: Chat messages as dicts with ``"role"`` and ``"content"`` keys.
    :param timeout: Seconds to wait for the server before giving up.
    :return: Text content of the first choice's message.
    :raises requests.HTTPError: If the server answers with an error status.
    :raises requests.Timeout: If the server does not answer within *timeout*.
    """
    data = {
        "messages": messages,
        # Stop if the model starts writing a new "### Instruction:" block.
        "stop": ["### Instruction:"],
        "temperature": 0,
        "max_tokens": 512,
        "stream": False,
    }
    response = requests.post(
        ngrok_url + "v1/chat/completions",
        headers={"Content-Type": "application/json"},
        json=data,
        timeout=timeout,
    )
    # Fail with a clear HTTP error instead of a KeyError on 'choices'.
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
# TODO: add ability to pass message history to model
def format_prompt_and_query(prompt, **kwargs):
    """
    Fill in *prompt* with *kwargs* and send it to the model as a single-turn chat.

    The formatted prompt becomes the user message and is preceded by a fixed system
    message; the pair is passed to :func:`call_model_with_history`.

    :param prompt: ``str.format`` template text.
    :param kwargs: Values substituted into the template's ``{placeholders}``.
    :return: The model's reply text, as returned by ``call_model_with_history``.
    """
    formatted_prompt = prompt.format(**kwargs)
    messages = [
        {"role": "system", "content": "Perform the instructions to the best of your ability."},
        {"role": "user", "content": formatted_prompt},
    ]
    return call_model_with_history(messages)
class HerbalExpert:
    """Answer herbal-medicine questions, refining the model's answer with PubMed abstracts.

    Pipeline (see :meth:`query_expert`):
    1. Get a first answer from the model.
    2. Extract search keywords from the question and from that first answer.
    3. Fetch highly cited PubMed articles for those keywords.
    4. Answer again with the articles as context.
    5. Ask the model to merge both answers.
    """

    def __init__(self, qd_chain):
        """
        :param qd_chain: Object whose ``run(question)`` returns comma-separated search
            keywords for the question (this module passes an ``AgentExecutor``).
        """
        self.qd_chain = qd_chain
        self.wnl = WordNetLemmatizer()
        self.default_questions = [
            "How is chamomile traditionally used in herbal medicine?",
            "What are the potential side effects or interactions of consuming echinacea?",
            "Can you explain the different methods of consuming lavender for health benefits?",
            "Which herbs are commonly known for their anti-inflammatory properties?",
            "I'm experiencing consistent stress and anxiety. What herbs or supplements could help alleviate these symptoms?",
            "Are there any natural herbs that could support better sleep?",
            "What cannabis or hemp products would you recommend for chronic pain relief?",
            "I'm looking to boost my immune system. Are there any specific herbs or supplements that could help?",
            "Which herbs or supplements are recommended for enhancing cognitive functions and memory?",
        ]
        # og = Original, qa = Question Asking, ri = Response Improvement
        self.prompts = {
            "og_answer_prompt": (
                "### Instruction: Answer the following question using the given context. Question: {question}\n"
                "Answer: ### Response: "
            ),
            "ans_decompose_prompt": (
                "### Instruction: Given the following text, identify the 2 most important\n"
                "keywords that capture the essence of the text. If there's a list of products, choose the top 2 products.\n"
                "Your response should be a list of only 2 keywords separated by commas. Text: {original_answer} Keywords:\n"
                "### Response: "
            ),
            "qa_prompt": (
                "### Instruction: Answer the following question using the given context.\n"
                "Question: {question}\n"
                "Context: {context}\n"
                "### Response: "
            ),
            "ri_prompt": (
                "### Instruction: You are an caring, intelligent question answering agent. Craft a\n"
                "response that is more informative and intelligent than the original answer and imparts knowledge from\n"
                "both the old answer and from the context only if it helps answer the question.\n"
                "Question: {question}\n"
                "Old Answer: {answer}\n"
                "Context: {answer2}\n"
                "Improved answer: ### Response:"
            ),
        }

    def process_query_words(self, question_words: str, answer_words: str):
        """
        Merge two comma-separated keyword strings into a de-duplicated list of search terms.

        Each term is lower-cased, stripped of surrounding whitespace and double quotes,
        and lemmatised. Empty terms and vague words that do not help a PubMed search
        are dropped. First-seen order is kept so the resulting query is reproducible.

        :param question_words: Keywords extracted from the question.
        :param answer_words: Keywords extracted from the model's first answer.
        :return: List of unique lemmatised keywords.
        """
        # don't need to be searching for these in pubmed. Should we include: 'supplements', 'supplement'
        vague_words = {'recommendation', 'recommendations', 'products', 'product'}
        words = question_words.lower().split(",") + answer_words.lower().split(",")
        final_list = []
        for word in words:
            cleaned = word.strip().strip('"')
            # Skip blanks (e.g. from a trailing comma) so the query never gets an empty AND term.
            if cleaned and cleaned not in vague_words:
                final_list.append(self.wnl.lemmatize(cleaned))
        # dict.fromkeys de-duplicates while keeping first-seen order (set() order is arbitrary).
        return list(dict.fromkeys(final_list))

    def convert_question_into_words(self, question: str):
        """
        Get the model's first answer to *question* and extract search keywords.

        Keywords come from ``self.qd_chain`` run on the question and from asking the
        model for the 2 most important keywords of its own first answer.

        :return: ``(keywords, original_answer)``, where *keywords* is the merged list
            from :meth:`process_query_words` and *original_answer* is the first answer.
        """
        original_answer = format_prompt_and_query(self.prompts["og_answer_prompt"], question=question)
        print("Original Answer: ", original_answer)
        question_decompose = self.qd_chain.run(question)
        print("Question Decompose: ", question_decompose)
        original_answer_decompose = format_prompt_and_query(
            self.prompts["ans_decompose_prompt"], original_answer=original_answer
        )
        print("Original Answer Decomposed: ", original_answer_decompose)
        words = self.process_query_words(question_decompose, original_answer_decompose)
        return words, original_answer

    def query_expert(self, question: Optional[str] = None):
        """
        Answer *question* (or a random default question) using PubMed context.

        :param question: Question to answer; when ``None`` one of
            ``self.default_questions`` is picked at random.
        :return: Dict with ``"question"``, ``"response"`` and ``"info"`` keys.
            ``info`` is ``"Success"`` when PubMed context was used. Otherwise it is
            ``"No context found"`` and ``response`` is the model's first answer.
        """
        if question is None:
            question = random.choice(self.default_questions)
        print("Question: ", question)
        keywords, original_response = self.convert_question_into_words(question)
        print("Keywords: ", keywords)
        # An empty keyword list would send an empty query to PubMed, so skip the search.
        context = fetch_pubmed_articles(" AND ".join(keywords), max_search=5) if keywords else []
        if not context:
            return {
                "question": question,
                "response": original_response,
                "info": "No context found",
            }
        # Put the articles into the prompt as readable text separated by blank lines,
        # not as a Python list repr with brackets, quotes and escaped newlines.
        context_text = "\n\n".join(context)
        contextual_response = format_prompt_and_query(
            self.prompts["qa_prompt"], question=question, context=context_text
        )
        improved_response = format_prompt_and_query(
            self.prompts["ri_prompt"], question=question,
            answer=original_response, answer2=contextual_response,
        )
        return {
            "question": question,
            "response": improved_response,
            "info": "Success",
        }
# Shared instance built at import time (presumably used by an importing app/API — confirm).
herbal_expert = HerbalExpert(ax_1)

if __name__ == '__main__':
    # Reuse the module-level instance instead of constructing a second HerbalExpert.
    answer = herbal_expert.query_expert(
        "I'm experiencing consistent stress and anxiety. What herbs or supplements could help alleviate these symptoms?"
    )
    print(answer['response'])