|
import os |
|
import streamlit as st |
|
from PIL import Image, ImageOps |
|
from langchain_openai import ChatOpenAI |
|
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings |
|
from langchain.vectorstores import FAISS |
|
from langchain.chains import RetrievalQA |
|
from langchain import PromptTemplate |
|
from langchain.retrievers import ContextualCompressionRetriever |
|
from langchain.retrievers.document_compressors import FlashrankRerank |
|
from dotenv import load_dotenv |
|
from langchain_community.embeddings.bedrock import BedrockEmbeddings |
|
load_dotenv() |
|
|
|
# Chunking parameters — presumably the values used when the source PDFs were
# split for indexing (the splitter itself is not in this file; verify against
# the indexing script).
PDF_CHUNK_SIZE = 1024
PDF_CHUNK_OVERLAP = 256

# Number of candidate chunks the FAISS retriever returns before reranking
# (consumed in main()).  NOTE(review): should be UPPER_SNAKE_CASE per PEP 8,
# but renaming would touch every reader.
k = 9
|
|
|
|
|
def load_and_pad_image(image_path, size=(64, 64)):
    """Open an image file and pad it (preserving aspect ratio) to *size*.

    Parameters
    ----------
    image_path : str or path-like
        Path to the image on disk.
    size : tuple[int, int], optional
        Target (width, height); defaults to 64x64 (favicon-sized).

    Returns
    -------
    PIL.Image.Image
        A new padded image; the source file handle is closed before return.
    """
    # Image.open is lazy and keeps the underlying file handle open until the
    # image object is garbage-collected.  The context manager closes it
    # deterministically; ImageOps.pad returns a new image, so the result
    # remains valid after the source file is closed.
    with Image.open(image_path) as img:
        return ImageOps.pad(img, size)
|
|
|
# Load the favicon once at import time; reused as both the page icon and the
# header image further down.  NOTE(review): assumes "medical.png" sits next to
# this script — Image.open raises FileNotFoundError otherwise.
favicon_path = "medical.png"
favicon_image = load_and_pad_image(favicon_path)
|
|
|
|
|
# Configure the browser tab title and icon.  NOTE(review): Streamlit requires
# set_page_config to be the first st.* command executed — keep it ahead of any
# other UI calls.
st.set_page_config(
    page_title="Chatbot",
    page_icon=favicon_image,
)
|
|
|
|
|
# Page header: narrow icon column on the left, title on the right (1:8 ratio).
col1, col2 = st.columns([1, 8])
with col1:
    st.image(favicon_image)
with col2:
    # Raw HTML h1 so the title can be nudged up to align with the icon.
    st.markdown(
        """
        <h1 style='text-align: left; margin-top: -12px;'>Chatbot</h1>
        """, unsafe_allow_html=True
    )
|
|
|
|
|
# Chat-model picker.  NOTE(review): get_llm() also special-cases
# "deepseek-chat", but that model is not offered here.
model_options = ["gpt-4o", "gpt-4o-mini"]
selected_model = st.selectbox("Choose a GPT model", model_options)

# Embedding-backend picker.  Only "OpenAI" is currently selectable, although
# load_vector_store() also supports a HuggingFace MedEmbed index.
embedding_model_options = ["OpenAI"]
selected_embedding_model = st.selectbox("Choose an Embedding model", embedding_model_options)
|
|
|
|
|
def get_llm(selected_model):
    """Build a ChatOpenAI client for the chosen model name.

    Parameters
    ----------
    selected_model : str
        Model identifier, e.g. "gpt-4o", "gpt-4o-mini", or "deepseek-chat".

    Returns
    -------
    ChatOpenAI
        Deterministic (temperature=0) client with no explicit token cap.
    """
    if selected_model == "deepseek-chat":
        api_key = os.getenv("DeepSeek_API_KEY")
        # DeepSeek exposes an OpenAI-compatible API, but only at its own
        # endpoint.  The original code selected the DeepSeek key yet left the
        # client pointed at api.openai.com, so this branch could never work.
        base_url = "https://api.deepseek.com"
    else:
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = None  # None -> ChatOpenAI's default OpenAI endpoint
    return ChatOpenAI(
        model=selected_model,
        temperature=0,
        max_tokens=None,
        api_key=api_key,
        base_url=base_url,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_vector_store(selected_embedding_model):
    """Load the FAISS index that matches the chosen embedding backend.

    "OpenAI" selects the text-embedding-3-large index; anything else falls
    back to the HuggingFace MedEmbed index.  The deserialized index must
    have been built with the same embedding model it is loaded with.
    Cached by Streamlit so the index is deserialized only once per selection.
    """
    if selected_embedding_model == "OpenAI":
        return FAISS.load_local(
            "faiss_table",
            OpenAIEmbeddings(
                model="text-embedding-3-large",
                api_key=os.getenv("OPENAI_API_KEY"),
            ),
            # Required by langchain to acknowledge pickle-based loading of a
            # locally trusted index.
            allow_dangerous_deserialization=True,
        )
    return FAISS.load_local(
        "faiss_index_medical_MedEmbed",
        HuggingFaceEmbeddings(model_name="abhinand/MedEmbed-large-v0.1"),
        allow_dangerous_deserialization=True,
    )
|
|
|
|
|
# Materialize the retriever index and LLM client for the current UI
# selections; these run on every Streamlit rerun (the store is cached).
vector_store = load_vector_store(selected_embedding_model)
llm = get_llm(selected_model)
|
|
|
|
|
def main():
    """Render the Q&A section: accept a question, retrieve and rerank
    context from the FAISS store, and answer with the selected LLM."""
    # Mirror the module-level store into session state; redundant on this
    # page, but reruns/other readers can pick it up from here.
    st.session_state['knowledge_base'] = vector_store
    st.header("Ask a Question")

    question = st.text_input("Enter your question")
    if st.button("Get Answer"):
        knowledge_base = st.session_state['knowledge_base']
        # Stage 1: top-k nearest chunks by embedding similarity (k=9 global).
        retriever = knowledge_base.as_retriever(search_kwargs={"k": k})
        # Stage 2: FlashRank reranks/compresses the candidates before the LLM.
        compressor = FlashrankRerank()
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=retriever
        )

        system_prompt = """
        You are a friendly and knowledgeable assistant who is an expert in medical education who will only answer from the context provided. You need to understand the best context to answer the question.
        """

        # f-string: {system_prompt} is interpolated immediately, while the
        # doubled braces {{context}}/{{question}} survive as literal
        # {context}/{question} placeholders for PromptTemplate.
        template = f"""
        {system_prompt}
        -------------------------------
        Context: {{context}}
        Question: {{question}}
        Answer:
        """

        prompt = PromptTemplate(
            template=template,
            input_variables=['context', 'question']
        )

        # Chain is rebuilt on every click; negligible next to the LLM call.
        qa_chain = RetrievalQA.from_chain_type(
            llm,
            retriever=compression_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt}
        )

        response = qa_chain.invoke({"query": question})
        # Sources are returned in response["source_documents"] but are not
        # surfaced in the UI.
        st.write(f"**Answer:** {response['result']}")
|
|
|
# Entry point when executed directly (streamlit run also imports as __main__).
if __name__ == "__main__":
    main()
|
|