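"""Enigma Chatbot: a Streamlit RAG application.

Loads PDF and web documents, splits them with a semantic chunker, indexes the
chunks in a local Chroma vector store, and answers questions through a
conversational retrieval chain whose retriever compresses context with an LLM.
"""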
import os
from pathlib import Path

from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Chroma
from langchain.llms.openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_experimental.text_splitter import SemanticChunker

import streamlit as st


LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")


def load_documents():
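    """Load each source URL, using PyPDFLoader for PDFs and WebBaseLoader otherwise."""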
    loaders = [
        PyPDFLoader(source_doc_url)
        if source_doc_url.endswith(".pdf")
        else WebBaseLoader(source_doc_url)
        for source_doc_url in st.session_state.source_doc_urls
    ]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    return documents


def split_documents(documents):
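    """Split documents at semantic-similarity boundaries instead of fixed character counts."""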
    text_splitter = SemanticChunker(OpenAIEmbeddings())
    texts = text_splitter.split_documents(documents)
    return texts


def embeddings_on_local_vectordb(texts):
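    """Index texts in a persistent local Chroma store and return a compression retriever.

    MMR search fetches 3 diverse chunks; an LLM-backed extractor then trims each
    chunk to the passages that are actually relevant to the query.
    """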
    vectordb = Chroma.from_documents(
        texts,
        embedding=OpenAIEmbeddings(),
        persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
    )
    vectordb.persist()
    retriever = ContextualCompressionRetriever(
        base_compressor=LLMChainExtractor.from_llm(OpenAI(temperature=0)),
        base_retriever=vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr"),
    )
    return retriever


def query_llm(retriever, query):
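    """Answer `query` via a ConversationalRetrievalChain over the chat history.

    Appends the (question, answer) turn to the session history and returns the
    source documents the chain used together with the answer text.
    """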
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(),
        retriever=retriever,
        return_source_documents=True,
    )
    result = qa_chain({"question": query, "chat_history": st.session_state.messages})
    answer = result["answer"]
    st.session_state.messages.append((query, answer))
    # Reuse the documents the chain already retrieved instead of querying the
    # retriever again, which would trigger a second round of LLM compression calls.
    return result["source_documents"], answer


def input_fields():
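    """Collect runtime settings (OpenAI API key, source document URLs) from the sidebar."""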
    # Never hardcode API keys in source; read the key from the sidebar at runtime.
    os.environ["OPENAI_API_KEY"] = st.sidebar.text_input("OpenAI API Key", type="password")
    st.session_state.source_doc_urls = [
        url.strip() for url in st.sidebar.text_input("Source Document URLs").split(",")
    ]


def process_documents():
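    """Load, split, and index the source documents, surfacing any failure in the UI."""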
    try:
        documents = load_documents()
        texts = split_documents(documents)
        st.session_state.retriever = embeddings_on_local_vectordb(texts)
    except Exception as e:
        st.error(f"An error occurred: {e}")


def boot():
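    """Render the chat UI: sidebar inputs, message history, and the main chat loop."""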
    st.title("Enigma Chatbot")
    input_fields()
    st.sidebar.button("Submit Documents", on_click=process_documents)
    st.sidebar.write("---")
    st.sidebar.write("References made during the chat will appear here")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(message[1])
    if query := st.chat_input():
        if "retriever" not in st.session_state:
            st.error("Please submit source documents before asking questions.")
            st.stop()
        st.chat_message("human").write(query)
        references, response = query_llm(st.session_state.retriever, query)
        for doc in references:
            # Web documents carry no "page" metadata, so fall back gracefully.
            page = doc.metadata.get("page", "n/a")
            st.sidebar.info(f"Page {page}\n\n{doc.page_content}")
        st.chat_message("ai").write(response)


if __name__ == "__main__":
    boot()