# -*- coding: utf-8 -*-
#!pip install gradio
#!pip install -U sentence-transformers
#!pip install langchain
#!pip install openai
#!pip install -U chromadb
import gradio as gr
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from langchain.llms import OpenAI
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain import LLMMathChain, SQLDatabase, SQLDatabaseChain, LLMChain
from langchain.agents import initialize_agent, Tool
from langchain.agents import ZeroShotAgent, AgentExecutor
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import AIMessage, HumanMessage
import sqlite3
import pandas as pd
import json
from functools import partial
import chromadb
import os

#cxn = sqlite3.connect('./data/mbr.db')
"""# import models""" | |
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1') | |
bi_encoder.max_seq_length = 256 #Truncate long passages to 256 tokens | |
#The bi-encoder will retrieve top_k documents. We use a cross-encoder, to re-rank the results list to improve the quality | |
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') | |
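
# Quick retrieve-then-rerank sanity check (illustrative query and passages, not from the corpus):
# q = 'Does Humana sell Medicare Advantage plans?'
# passages = ['Humana offers Medicare Advantage plans.', 'The sky is blue.']
# hits = util.semantic_search(bi_encoder.encode(q), bi_encoder.encode(passages), top_k=2)
# scores = cross_encoder.predict([[q, p] for p in passages])  # higher score = more relevant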
"""# setup vector db | |
- chromadb | |
- https://docs.trychroma.com/getting-started | |
""" | |
from chromadb.config import Settings | |
chroma_client = chromadb.Client(settings=Settings( | |
chroma_db_impl="duckdb+parquet", | |
persist_directory="./data/mychromadb/" # Optional, defaults to .chromadb/ in the current directory | |
)) | |
#!ls ./data/mychromadb/ | |
#collection = chroma_client.create_collection(name="benefit_collection") | |
collection = chroma_client.get_collection(name="plan_collection", embedding_function=bi_encoder) | |
faq_collection = chroma_client.get_collection(name="faq_collection", embedding_function=bi_encoder) | |
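
# Building a collection from scratch looks roughly like this (a sketch; the id, text, url,
# and page values are illustrative, but 'source' and 'company' are the metadata keys the
# filters and the chain-of-thought display in this file expect):
# collection = chroma_client.create_collection(name="plan_collection")
# collection.add(
#     ids=['plan-0'],
#     embeddings=[bi_encoder.encode('Humana plan H1036-236 covers ...').tolist()],
#     documents=['Humana plan H1036-236 covers ...'],
#     metadatas=[{'source': 'plan_doc', 'company': 'humana', 'url': 'https://example.com', 'page': 1}],
# )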
"""### vector db search examples""" | |
def rtrv(qry, collection, top_k=20): | |
results = collection.query( | |
query_embeddings=[ bi_encoder.encode(qry) ], | |
n_results=top_k, | |
) | |
return results | |
def vdb_src(qry, collection, src, top_k=20): | |
results = collection.query( | |
query_embeddings=[ bi_encoder.encode(qry) ], | |
n_results=top_k, | |
where={"source": src}, | |
) | |
return results | |
def vdb_where(qry, collection, where, top_k=20): | |
results = collection.query( | |
query_embeddings=[ bi_encoder.encode(qry) ], | |
n_results=top_k, | |
where=where, | |
) | |
return results | |
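
# e.g. (illustrative queries; the 'plan_doc' source and company values assume the metadata
# schema used when the collections were built):
# rtrv('transportation benefit', collection, top_k=5)
# vdb_src('transportation benefit', collection, src='plan_doc', top_k=5)
# vdb_where('transportation benefit', collection, where={'company': 'humana'}, top_k=5)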
def vdb_pretty(qry, collection, top_k=10):
    results = collection.query(
        query_embeddings=[bi_encoder.encode(qry)],
        n_results=top_k,
        include=["metadatas", "documents", "distances", "embeddings"]
    )
    rslt_pd = pd.DataFrame(results).explode(['ids', 'documents', 'metadatas', 'distances', 'embeddings'])
    rslt_fmt = pd.concat([rslt_pd.drop(['metadatas'], axis=1), rslt_pd['metadatas'].apply(pd.Series)], axis=1)
    return rslt_fmt

# qry = 'Why should I choose Medicare Advantage over traditional Medicare?'
# rslt_fmt = vdb_pretty(qry, collection, top_k=10)
# rslt_fmt
# doc_lst = rslt_fmt[['documents']].values.tolist()
# len(doc_lst)
"""# Introduction | |
- example of the kind of question answering that is possible with this tool | |
- assumes we are answering for a member with a Healthy Options Card | |
*When will I get my card?* | |
# semantic search functions | |
""" | |
# choosing to use re-ranking for this use case as a baseline
def rernk(query, collection=collection, where=None, top_k=20, top_n=5):
    rtrv_rslts = vdb_where(query, collection=collection, where=where, top_k=top_k)
    rtrv_ids = rtrv_rslts.get('ids')[0]
    rtrv_docs = rtrv_rslts.get('documents')[0]

    ##### Re-Ranking #####
    cross_inp = [[query, doc] for doc in rtrv_docs]
    cross_scores = cross_encoder.predict(cross_inp)

    # sort results by the cross-encoder scores and keep the top_n ids
    combined = list(zip(rtrv_ids, list(cross_scores)))
    sorted_tuples = sorted(combined, key=lambda x: x[1], reverse=True)
    sorted_ids = [t[0] for t in sorted_tuples[:top_n]]
    predictions = collection.get(ids=sorted_ids, include=["documents", "metadatas"])
    return predictions
    #return cross_scores
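
# e.g. (illustrative query):
# rernk('dental coverage', collection=collection, where={'company': 'humana'})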
## version w/o re-rank
# def get_text_fmt(qry):
#     prediction_text = []
#     predictions = rtrv(qry, collection, top_k=5)
#     docs = predictions['documents'][0]
#     meta = predictions['metadatas'][0]
#     for i in range(len(docs)):
#         result = Document(page_content=docs[i], metadata=meta[i])
#         prediction_text.append(result)
#     return prediction_text

def get_text_fmt(qry, collection=collection, where=None):
    prediction_text = []
    predictions = rernk(qry, collection=collection, where=where, top_k=20, top_n=5)
    docs = predictions['documents']
    meta = predictions['metadatas']
    for i in range(len(docs)):
        result = Document(page_content=docs[i], metadata=meta[i])
        prediction_text.append(result)
    return prediction_text

# get_text_fmt('why should I choose a medicare advantage plan over traditional medicare?')
"""# LLM based qa functions""" | |
llm = OpenAI(temperature=0) | |
# default model | |
# model_name: str = "text-davinci-003" | |
# instruction fine-tuned, sometimes referred to as GPT-3.5 | |
template = """You are a friendly AI assistant for the insurance company Humana. | |
Given the following extracted parts of a long document and a question, create a succinct final answer. | |
If you don't know the answer, just say that you don't know. Don't try to make up an answer. | |
If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana. | |
QUESTION: {question} | |
========= | |
{summaries} | |
========= | |
FINAL ANSWER:""" | |
PROMPT = PromptTemplate(template=template, input_variables=["summaries", "question"]) | |
chain_qa = load_qa_with_sources_chain(llm=llm, chain_type="stuff", prompt=PROMPT, verbose=False) | |
def get_llm_response(message, collection=collection, where=None): | |
mydocs = get_text_fmt(message, collection, where) | |
responses = chain_qa({"input_documents":mydocs, "question":message}) | |
return responses | |
get_llm_response_humana = partial(get_llm_response, where={'company':'humana'}) | |
get_llm_response_essence = partial(get_llm_response, where={'company':'essence'}) | |
get_llm_response_faq = partial(get_llm_response, collection=faq_collection) | |
# rslt = get_llm_response('can I buy shrimp?') | |
# rslt['output_text'] | |
# for d in rslt['input_documents']: | |
# print(d.page_content) | |
# print(d.metadata['url']) | |
# rslt['output_text'] | |
"""# Database query""" | |
## setup member database | |
## only do this once | |
# d = {'mbr_fname':['bruce'], | |
# 'mbr_lname':['broussard'], | |
# 'mbr_id':[456] , | |
# 'policy_id':['H1036-236'], | |
# 'accumulated_out_of_pocket':[3800], | |
# 'accumulated_routine_footcare_visits':[6], | |
# 'accumulated_trasportation_trips':[22], | |
# 'accumulated_drug_cost':[7500], | |
# } | |
# df = pd.DataFrame(data=d, columns=['mbr_fname', 'mbr_lname', 'mbr_id', 'policy_id', 'accumulated_out_of_pocket', 'accumulated_routine_footcare_visits', 'accumulated_trasportation_trips','accumulated_drug_cost']) | |
# df.to_sql(name='mbr_details', con=cxn, if_exists='replace') | |
# # sample db query | |
# qry = '''select accumulated_routine_footcare_visits | |
# from mbr_details''' | |
# foot_det = pd.read_sql(qry, cxn) | |
# foot_det.values[0][0] | |
#db = SQLDatabase.from_uri("sqlite:///./data/mbr.db") | |
#db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, return_intermediate_steps=True) | |
#def db_qry(qry): | |
# responses = db_chain('my mbr_id is 456 ;'+str(qry) ) ############### hardcode mbr id 456 for demo | |
# return responses | |
"""# Math | |
- default version | |
""" | |
llm_math_chain = LLMMathChain(llm=llm, verbose=True) | |
# llm_math_chain.run('what is the square root of 49?') | |
"""# Greeting""" | |
template = """You are an AI assistant for the insurance company Humana. | |
Your name is Jarvis and you were created on February 13, 2023. | |
Offer polite, friendly greetings and brief small talk. | |
Respond to thanks with, 'Glad to help.' | |
If the question is not about Humana, politely guide the user to ask questions about Humana insurance benefits | |
QUESTION: {question} | |
========= | |
FINAL ANSWER:""" | |
greet_prompt = PromptTemplate(template=template, input_variables=["question"]) | |
greet_llm = LLMChain(prompt=greet_prompt, llm=llm, verbose=True) | |
# greet_llm.run('will it snow in Lousiville tomorrow') | |
# greet_llm.run('Thanks, that was great') | |
"""# MRKL Chain""" | |
tools = [ | |
Tool( | |
name = "Humana Plans", | |
func=get_llm_response_humana, | |
description='''Useful for confirming benefits of Humana plans. | |
Useful for answering questions about Humana insurance plans. | |
You should ask targeted questions.''' | |
), | |
Tool( | |
name = "Essence Plans", | |
func=get_llm_response_essence, | |
description='''Useful for confirming benefits of Essence Healthcare plans. | |
Useful for answering questions about Essence Healthcare plans. | |
You should ask targeted questions.''' | |
), | |
Tool( | |
name = "FAQ", | |
func=get_llm_response_faq, | |
description='''Useful for answering general health insurance questions. Useful for answering questions about Medicare and | |
Medicare Advantage. ''' | |
), | |
Tool( | |
name="Calculator", | |
func=llm_math_chain.run, | |
description="""Only useful for when you need to answer questions about math, like subtracting two numbers or dividing numbers. | |
This tool should not be used to look up facts.""" | |
), | |
#Tool( | |
# name = "Search", | |
# func=search.run, | |
# description="Useful for when you need to answer questions than can not be answered using the other tools. This tool is a last resort." | |
#), | |
Tool( | |
name="Greeting", | |
func=greet_llm.run, | |
return_direct=True, | |
description="useful for when you need to respond to greetings, thanks, make small talk or answer questions about yourself" | |
), | |
] | |
##### Create Agent
#mrkl = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=False, return_intermediate_steps=True, max_iterations=5, early_stopping_method="generate")

prefix = """Answer the following question as best you can. You should not make up any answers. To answer the question, use the following tools:"""
suffix = """If the question is not about healthcare or Humana,
you should use the "Greeting" tool and pass it the question being asked.
If you are not confident in which tool to use,
you should use the "Greeting" tool and pass it the question being asked.
Remember, only answer using the information output from the tools! Begin!

{chat_history}
Question: {input}
{agent_scratchpad}"""

prompt = ZeroShotAgent.create_prompt(
    tools,
    prefix=prefix,
    suffix=suffix,
    input_variables=["input", "chat_history", "agent_scratchpad"]
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True,
                                                 max_iterations=5, early_stopping_method="generate",
                                                 return_intermediate_steps=True)
## version that reuses the single agent_chain above and swaps its memory in place
## (superseded by the definitions below, which build a fresh executor per call)
# def make_memory_buffer(history, mem_len=2):
#     mem = ConversationBufferWindowMemory(k=mem_len, memory_key="chat_history", output_key="output")
#     hist = []
#     for user, ai in history:
#         hist += [HumanMessage(content=user), AIMessage(content=ai)]
#     mem.chat_memory.messages = hist
#     return mem
#
# def agent_rspnd(qry, history, agent=agent_chain):
#     agent.memory = make_memory_buffer(history)
#     response = agent({"input": str(qry)})
#     return response
def make_memory_buffer(history, mem_len=2):
    hist = []
    for user, ai in history:
        hist += [HumanMessage(content=user), AIMessage(content=ai)]
    mem = ConversationBufferWindowMemory(k=mem_len, memory_key="chat_history", output_key="output")
    mem.chat_memory.messages = hist
    return mem
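
# e.g. (illustrative turn):
# mem = make_memory_buffer([('hi', 'Hello! How can I help?')])
# mem.load_memory_variables({})  # -> {'chat_history': ...the last k turns as text...}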
def agent_rspnd(qry, history):
    agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True,
                                                     memory=make_memory_buffer(history),
                                                     max_iterations=5, early_stopping_method="generate",
                                                     return_intermediate_steps=True)
    response = agent_chain({"input": str(qry)})
    return response
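
# e.g. (illustrative question):
# r = agent_rspnd('What dental benefits does Humana offer?', history=[])
# r['output']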
## requires the initialize_agent() call above to be uncommented; mrkl is otherwise undefined
# def mrkl_rspnd(qry):
#     response = mrkl({"input": str(qry)})
#     return response

# r = mrkl_rspnd("can I buy fish with the card?")
# print(r['output'])
# print(json.dumps(r['intermediate_steps'], indent=2))
# r['intermediate_steps']
# from IPython.core.display import display, HTML

def get_cot(r):
    """Format the agent's intermediate steps as an HTML chain-of-thought trace."""
    cot = '<p>'
    try:
        intermedObj = r['intermediate_steps']
        cot += '<b>Input:</b> ' + r['input'] + '<br>'
        for agnt_action, obs in intermedObj:
            al = '<br> '.join(agnt_action.log.split('\n'))
            cot += '<b>AI chain of thought:</b> ' + al + '<br>'
            if isinstance(obs, dict):
                if obs.get('input_documents') is not None:  # this criterion doesn't work
                    for d in obs['input_documents']:
                        cot += ' <i>- ' + str(d.page_content) + '</i> <a href="' + str(d.metadata['url']) + '">' + str(d.metadata['page']) + '</a> <br>'
                    cot += '<b>Observation:</b> ' + str(obs['output_text']) + '<br><br>'
                elif obs.get('intermediate_steps') is not None:
                    cot += '<b>Query:</b> ' + str(obs.get('intermediate_steps')) + '<br><br>'
                else:
                    pass
            else:
                cot += '<b>Observation:</b> ' + str(obs) + '<br><br>'
    except Exception:
        pass
    cot += '</p>'
    return cot

# cot = get_cot(r)
# display(HTML(cot))
"""# chat example""" | |
def chat(message, history): | |
history = history or [] | |
#message = message.lower() | |
response = agent_rspnd(message, history) | |
cot = get_cot(response) | |
history.append((message, response['output'])) | |
return history, history, cot | |
css=".gradio-container {background-color: whitesmoke}" | |
xmpl_list = ["How does Humana's transportation benefit compare to Essence's?", | |
"Why should I choose a Medicare Advantage plan over Traditional Medicare?", | |
"What is the difference between a Medicare Advantage HMO plan and a PPO plan?", | |
"What is a low income subsidy plan and do I qualify for one of these plans?", | |
"Are my medications covered on a low income subsidy plan?"] | |
with gr.Blocks(css=css) as demo:
    history_state = gr.State()
    response_state = gr.State()
    gr.Markdown('# Sales QA Bot')
    gr.Markdown("""You are a **Louisville, KY** resident who currently has **Medicare Advantage** through an insurer called
    **Essence Healthcare**. You don't know a lot about Medicare Advantage or your current benefits, so you may have questions about
    how Humana's plans compare. This bot is here to help you learn about what **Humana has to offer** while answering any
    other questions you might have. Welcome!""")
    with gr.Row():
        chatbot = gr.Chatbot()
    with gr.Accordion(label='Show AI chain of thought: ', open=False):
        ai_cot = gr.HTML(show_label=False)
    with gr.Row():
        message = gr.Textbox(label='Input your question here:',
                             placeholder='Why should I choose Medicare Advantage?',
                             lines=1)
        submit = gr.Button(value='Send',
                           variant='secondary').style(full_width=False)
    submit.click(chat,
                 inputs=[message, history_state],
                 outputs=[chatbot, history_state, ai_cot])
    gr.Examples(
        examples=xmpl_list,
        inputs=message
    )

demo.launch()