import os
import streamlit as st
from PIL import Image, ImageOps
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank

# Read API keys (OPENAI_API_KEY, DeepSeek_API_KEY) from a local .env file
load_dotenv()
# Hyperparameters
PDF_CHUNK_SIZE = 1024    # chunk size (characters) used when the PDF index is built
PDF_CHUNK_OVERLAP = 256  # overlap between consecutive chunks
TOP_K = 9                # number of candidates retrieved from FAISS before reranking
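
# The two chunking hyperparameters above are not used by this app directly; they
# presumably apply when the FAISS indexes loaded below are built offline. A
# minimal sketch of that indexing step, assuming PyPDF ingestion (the loader and
# the "medical.pdf" path are illustrative, not taken from this app):
#
#     from langchain_community.document_loaders import PyPDFLoader
#     from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#     docs = PyPDFLoader("medical.pdf").load()
#     splitter = RecursiveCharacterTextSplitter(
#         chunk_size=PDF_CHUNK_SIZE, chunk_overlap=PDF_CHUNK_OVERLAP
#     )
#     embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
#     FAISS.from_documents(splitter.split_documents(docs), embeddings).save_local("faiss_table")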

# Load favicon image
def load_and_pad_image(image_path, size=(64, 64)):
    """Open an image and pad it to a square canvas for use as a page icon."""
    img = Image.open(image_path)
    return ImageOps.pad(img, size)

favicon_path = "medical.png"
favicon_image = load_and_pad_image(favicon_path)

# Streamlit Page Config
st.set_page_config(
    page_title="Chatbot",
    page_icon=favicon_image,
)

# Set up logo and title
col1, col2 = st.columns([1, 8])
with col1:
    st.image(favicon_image)
with col2:
    st.markdown(
        """
        <h1 style='text-align: left; margin-top: -12px;'>Chatbot</h1>
        """, unsafe_allow_html=True
    )

# Model and Embedding Selection
model_options = ["gpt-4o", "gpt-4o-mini"] #, "deepseek-chat"
selected_model = st.selectbox("Choose a GPT model", model_options)

embedding_model_options = ["OpenAI"] #, "Huggingface MedEmbed"
selected_embedding_model = st.selectbox("Choose an Embedding model", embedding_model_options)

# Load the model
def get_llm(selected_model):
    if selected_model == "deepseek-chat":
        # DeepSeek serves an OpenAI-compatible API from its own endpoint, so
        # the base_url must be overridden along with the API key.
        return ChatOpenAI(
            model=selected_model,
            temperature=0,
            max_tokens=None,
            api_key=os.getenv("DeepSeek_API_KEY"),
            base_url="https://api.deepseek.com",
        )
    return ChatOpenAI(
        model=selected_model,
        temperature=0,
        max_tokens=None,
        api_key=os.getenv("OPENAI_API_KEY"),
    )

# Cache the vector store loading
@st.cache_resource
def load_vector_store(selected_embedding_model):
    """Load the prebuilt FAISS index matching the selected embedding model."""
    if selected_embedding_model == "OpenAI":
        embeddings = OpenAIEmbeddings(model="text-embedding-3-large", api_key=os.getenv("OPENAI_API_KEY"))
        return FAISS.load_local("faiss_table", embeddings, allow_dangerous_deserialization=True)
    else:
        embeddings = HuggingFaceEmbeddings(model_name="abhinand/MedEmbed-large-v0.1")
        return FAISS.load_local("faiss_index_medical_MedEmbed", embeddings, allow_dangerous_deserialization=True)

# Load the selected vector store
vector_store = load_vector_store(selected_embedding_model)
llm = get_llm(selected_model)

# Main App Logic
def main():
    st.header("Ask a Question")

    question = st.text_input("Enter your question")
    if st.button("Get Answer"):
        retriever = vector_store.as_retriever(search_kwargs={"k": TOP_K})
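        # Rerank the TOP_K FAISS candidates with Flashrank's lightweight
        # cross-encoder and keep only the most relevant chunks for the prompt.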
        compressor = FlashrankRerank()
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=retriever
        )

        system_prompt = """
        You are a friendly and knowledgeable assistant who is an expert in medical education who will only answer from the context provided. You need to understand the best context to answer the question.
        """

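        # The doubled braces survive the f-string as literal {context} and
        # {question}, which PromptTemplate then treats as input variables.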
        template = f"""
        {system_prompt}
        -------------------------------
        Context: {{context}}
        Question: {{question}}
        Answer:
        """

        prompt = PromptTemplate(
            template=template,
            input_variables=['context', 'question']
        )

        qa_chain = RetrievalQA.from_chain_type(
            llm,
            retriever=compression_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt}
        )

        response = qa_chain.invoke({"query": question})
        st.write(f"**Answer:** {response['result']}")

if __name__ == "__main__":
    main()