import os
import streamlit as st
from PIL import Image, ImageOps
from langchain_openai import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from dotenv import load_dotenv
from langchain_community.embeddings.bedrock import BedrockEmbeddings  # currently unused
load_dotenv()
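# Expects a local .env file providing OPENAI_API_KEY (and DeepSeek_API_KEY if the
# DeepSeek model option below is re-enabled).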
# Hyperparameters
PDF_CHUNK_SIZE = 1024    # PDF chunk size (not used directly in this file)
PDF_CHUNK_OVERLAP = 256  # PDF chunk overlap (not used directly in this file)
k = 9                    # number of documents the retriever returns before reranking
# Load favicon image
def load_and_pad_image(image_path, size=(64, 64)):
    """Open an image and pad it to a square size for use as the page icon."""
    img = Image.open(image_path)
    return ImageOps.pad(img, size)
favicon_path = "medical.png"
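# "medical.png" is expected to sit next to this script; Image.open raises FileNotFoundError if it is missing.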
favicon_image = load_and_pad_image(favicon_path)
# Streamlit Page Config
st.set_page_config(
    page_title="Chatbot",
    page_icon=favicon_image,
)
# Set up logo and title
col1, col2 = st.columns([1, 8])
with col1:
    st.image(favicon_image)
with col2:
    st.markdown(
        """
        <h1 style='text-align: left; margin-top: -12px;'>Chatbot</h1>
        """, unsafe_allow_html=True
    )
# Model and Embedding Selection
model_options = ["gpt-4o", "gpt-4o-mini"] #, "deepseek-chat"
selected_model = st.selectbox("Choose a GPT model", model_options)
embedding_model_options = ["OpenAI"] #, "Huggingface MedEmbed"
selected_embedding_model = st.selectbox("Choose an Embedding model", embedding_model_options)
# Load the model
def get_llm(selected_model):
    """Return a ChatOpenAI client for the selected model, using the matching API key from the environment."""
    api_key = os.getenv("DeepSeek_API_KEY") if selected_model == "deepseek-chat" else os.getenv("OPENAI_API_KEY")
    return ChatOpenAI(
        model=selected_model,
        temperature=0,
        max_tokens=None,
        api_key=api_key,
    )
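# Note: "deepseek-chat" targets DeepSeek's OpenAI-compatible API, so if that option
# is re-enabled above, ChatOpenAI would likely also need the DeepSeek endpoint,
# e.g. (sketch, not verified here):
#     ChatOpenAI(model="deepseek-chat", temperature=0, api_key=api_key,
#                base_url="https://api.deepseek.com")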
# Cache the vector store loading so the FAISS index is only read from disk once.
# (An earlier revision loaded "faiss_index_medical_OpenAI" for the OpenAI
# embeddings; the active version below uses "faiss_table" instead.)
@st.cache_resource
def load_vector_store(selected_embedding_model):
    if selected_embedding_model == "OpenAI":
        embeddings = OpenAIEmbeddings(model="text-embedding-3-large", api_key=os.getenv("OPENAI_API_KEY"))
        return FAISS.load_local("faiss_table", embeddings, allow_dangerous_deserialization=True)
    else:
        embeddings = HuggingFaceEmbeddings(model_name="abhinand/MedEmbed-large-v0.1")
        return FAISS.load_local("faiss_index_medical_MedEmbed", embeddings, allow_dangerous_deserialization=True)
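# Both FAISS index folders ("faiss_table", "faiss_index_medical_MedEmbed") are assumed to have
# been built offline with the matching embedding model; allow_dangerous_deserialization=True is
# required because FAISS.load_local unpickles the stored docstore.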
# Load the selected vector store
vector_store = load_vector_store(selected_embedding_model)
llm = get_llm(selected_model)
# Main App Logic
def main():
    st.session_state['knowledge_base'] = vector_store

    st.header("Ask a Question")
    question = st.text_input("Enter your question")

    if st.button("Get Answer"):
        knowledge_base = st.session_state['knowledge_base']

        # Retrieve the top-k chunks from FAISS, then rerank them with FlashRank
        # before they are passed to the LLM.
        retriever = knowledge_base.as_retriever(search_kwargs={"k": k})
        compressor = FlashrankRerank()
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=retriever
        )

        system_prompt = """
        You are a friendly and knowledgeable assistant and an expert in medical education.
        Answer only from the provided context, using the parts of the context that best fit the question.
        """
        template = f"""
        {system_prompt}
        -------------------------------
        Context: {{context}}
        Question: {{question}}
        Answer:
        """
        prompt = PromptTemplate(
            template=template,
            input_variables=['context', 'question']
        )

        qa_chain = RetrievalQA.from_chain_type(
            llm,
            retriever=compression_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt}
        )

        response = qa_chain.invoke({"query": question})
        st.write(f"**Answer:** {response['result']}")
if __name__ == "__main__":
    main()
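# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py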