File size: 6,801 Bytes
009313d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import gradio as gr

from langchain.chains import (
    ConversationalRetrievalChain,
    LLMChain,
    MapReduceDocumentsChain,
    ReduceDocumentsChain,
    StuffDocumentsChain,
)
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain_community.document_loaders import WebBaseLoader


def wait_for_summarization(url):
    """Return a one-entry chat history telling the user a summary is in progress.

    Args:
        url: The page address to mention in the notice.

    Returns:
        A list holding a single ``(None, message)`` tuple, i.e. a bot-only turn.
    """
    notice = f"Please wait while I summarize the contents of {url}..."
    bot_only_turn = (None, notice)
    return [bot_only_turn]


def load_page(url, api_key, history):
    """Load a web page, summarize it with a map-reduce chain, and append the summary.

    Side effects: rebinds the module-level globals ``docs`` (loaded documents),
    ``llm`` (the chat model) and ``summary`` (the generated summary text).

    Args:
        url: Address passed to ``WebBaseLoader``.
        api_key: OpenAI API key passed to ``ChatOpenAI``.
        history: Current chatbot history (a list of ``(user, bot)`` tuples).

    Returns:
        A new list consisting of ``history`` followed by ``(None, summary)``.
    """
    global docs, summary, llm

    docs = WebBaseLoader(url).load()
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo-1106", temperature=0, openai_api_key=api_key
    )

    # Map step: extract the main themes from each chunk of the page.
    themes_prompt = PromptTemplate.from_template(
        """The following is a set of snippets from a web page:
{docs}
Based on this list of snippets, please identify the main themes
Helpful Answer:"""
    )
    per_chunk_chain = LLMChain(llm=llm, prompt=themes_prompt)

    # Reduce step: merge the per-chunk themes into one consolidated summary.
    consolidate_prompt = PromptTemplate.from_template(
        """The following is set of summaries of a web page:
{docs}
Take these and distill it into a final, consolidated summary of the main themes.
Helpful Answer:"""
    )
    consolidate_chain = LLMChain(llm=llm, prompt=consolidate_prompt)

    # Joins a list of documents into a single string and feeds it to the reduce LLMChain.
    stuff_chain = StuffDocumentsChain(
        llm_chain=consolidate_chain, document_variable_name="docs"
    )

    summarizer = MapReduceDocumentsChain(
        # Chain applied to every chunk individually.
        llm_chain=per_chunk_chain,
        # The same stuffing chain is used both for the final combine and for
        # collapsing intermediate results grouped up to ``token_max`` tokens.
        reduce_documents_chain=ReduceDocumentsChain(
            combine_documents_chain=stuff_chain,
            collapse_documents_chain=stuff_chain,
            token_max=4000,
        ),
        # Name of the prompt variable in ``llm_chain`` that receives the documents.
        document_variable_name="docs",
        # Only the final summary is wanted, not the per-chunk outputs.
        return_intermediate_steps=False,
    )

    # Chunk the page by token count (tiktoken) without overlap before mapping.
    chunks = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=0
    ).split_documents(docs)

    summary = summarizer.run(chunks)
    summary_turn = (None, summary)
    return history + [summary_turn]


def prepare_chat(api_key, history):
    """Index the loaded page and build the conversational retrieval chain.

    Reads the module-level globals ``docs``, ``summary`` and ``llm`` (assigned by
    ``load_page``) and rebinds the module-level global ``qa`` to a new
    ``ConversationalRetrievalChain``.

    Args:
        api_key: OpenAI API key passed to ``OpenAIEmbeddings``.
        history: Current chatbot history (a list of ``(user, bot)`` tuples).

    Returns:
        A new list consisting of ``history`` followed by a bot-only turn telling
        the user that questions can now be asked.
    """
    global docs, summary, llm, qa
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=128)
    documents = text_splitter.split_documents(docs)
    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
    vectorstore = Chroma.from_documents(documents, embeddings)
    # Retrieve the 6 most similar chunks for each question.
    retriever = vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": 6}
    )

    # The summary is model-generated free text that gets spliced into an
    # f-string-style template. Literal braces in it would be parsed as template
    # placeholders, so double them to keep them as plain characters.
    escaped_summary = summary.replace("{", "{{").replace("}", "}}")

    qa_prompt_template = (
        """As an AI assistant you help in answering questions about the contents of a web page.
The summary of the current web page is this:

"""
        + escaped_summary
        + """

Also, consider this additional context that may be relevant for the user's question:

{context}

Please answer following question: {question}"""
    )

    qa_prompt = PromptTemplate(
        template=qa_prompt_template, input_variables=["context", "question"]
    )

    # ``output_key="answer"`` tells the memory which chain output to store.
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True, output_key="answer"
    )
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )
    return history + [(None, "You can now ask me specific questions about the page.")]


def chatbot_function(message, history):
    """Answer a user message with the retrieval chain and extend the chat history.

    Args:
        message: Text the user submitted.
        history: Current chatbot history (a list of ``(user, bot)`` tuples).

    Returns:
        A ``(textbox_value, history)`` tuple. The textbox value is always ``""``
        so the input is cleared. For a blank message the history is returned
        unchanged; otherwise a new list with the ``(message, answer)`` turn
        appended is returned.
    """
    global qa
    # A blank submission carries no question; skip the retrieval/LLM round trip
    # instead of sending an empty query through the chain.
    if not message or not message.strip():
        return "", history
    return "", history + [(message, qa.run(message))]


def build_demo():
    """Assemble the Gradio UI and wire up its event handlers.

    The layout has a configuration row (API key, URL, "Load" button) and an
    initially hidden chat row containing the chatbot and a hidden input row.
    Clicking "Load" hides the configuration, shows the chat, summarizes the
    page, prepares the QA chain and finally reveals the input row.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        with gr.Row() as config_row:
            with gr.Column():
                api_key_box = gr.Textbox(
                    show_label=False,
                    placeholder="OpenAI API Key",
                    container=False,
                    autofocus=True,
                    # The key is a secret: mask it instead of echoing it on screen.
                    type="password",
                )
                url_box = gr.Textbox(
                    show_label=False,
                    placeholder="URL",
                    container=False,
                )
                load_btn = gr.Button(value="Load", variant="primary")
        with gr.Row(visible=False) as chat_row:
            with gr.Column():
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        label="Web Chat",
                        height=550,
                    )
                with gr.Row(visible=False) as inputs_row:
                    with gr.Column(scale=8):
                        text_box = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and press ENTER",
                            autofocus=True,
                            container=False,
                        )
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(
                            value="Send",
                            variant="primary",
                        )

        # Load pipeline. The UI-only steps are chained with ``.then``; the steps
        # that depend on the previous one having produced its result use
        # ``.success`` so that a failed page load does not go on to build the
        # QA chain or reveal the input row.
        load_btn.click(
            lambda: gr.update(visible=False),
            outputs=[config_row],
        ).then(
            lambda: gr.update(visible=True),
            outputs=[chat_row],
        ).then(
            wait_for_summarization,
            inputs=[url_box],
            outputs=[chatbot],
        ).then(
            load_page,
            inputs=[url_box, api_key_box, chatbot],
            outputs=[chatbot],
        ).success(
            prepare_chat,
            inputs=[api_key_box, chatbot],
            outputs=[chatbot],
        ).success(
            lambda: gr.update(visible=True),
            outputs=[inputs_row],
        )

        # Pressing ENTER and clicking "Send" both submit the question; the
        # handler clears the textbox and returns the updated history.
        text_box.submit(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )
        submit_btn.click(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )

    return demo


# Script entry point: build the UI and start the Gradio server only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo = build_demo()
    demo.launch()