# Written against the pre-1.0 openai SDK and the legacy langchain namespace;
# newer releases moved or renamed several of these imports.
import pickle
import time

import gradio as gr
import openai
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings  # embedding class referenced by the pickled store
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate


# Load the pre-built FAISS vector store of report embeddings from disk.
# embedding_file = "all_faiss_store_openai.pkl"
embedding_file = "year_public_store_openai.pkl"
with open(embedding_file, "rb") as f:
    VectorStore = pickle.load(f)
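
# A minimal sketch of how a pickle like this might be produced, assuming the
# store is a LangChain FAISS index built from PDF pages with OpenAI embeddings
# (the real build script is not part of this file, and "chartis_report.pdf" is
# a hypothetical input; the live code below also expects a 'title' entry in
# each chunk's metadata):
#
#     from langchain.document_loaders import PyPDFLoader
#     from langchain.vectorstores import FAISS
#
#     pages = PyPDFLoader("chartis_report.pdf").load()
#     store = FAISS.from_documents(pages, OpenAIEmbeddings())
#     with open("year_public_store_openai.pkl", "wb") as f:
#         pickle.dump(store, f)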


""" initialize all the tools """

template = """
You are a knowledgeable assistant for Chartis' reports, and you are cautious about the answers you give. Refuse to answer any question whose answer would violate the OpenAI usage policy or that is unrelated to the provided documents.
Given the user's question: {question}
• If the question can be answered from the provided context, use the context to formulate your answer.
• If the question cannot be answered from the context, simply state that you don't know. Do not provide inaccurate or made-up information.
Your answers should be:
• Direct and succinct.
• Accurate, directly addressing the user's question.
{context}
Helpful Answer:"""


QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
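
# Note: RetrievalQA's default "stuff" chain type concatenates the text of the
# retrieved chunks into {context} before calling the model, so the prompt
# above sees both the supporting passages and the user's question.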

# Filled in at runtime via the Gradio key textbox below.
OPENAI_API_KEY = ''


def get_openai_key(openai_key):
    # Store the user-supplied key and reveal the hidden chatbot column.
    global OPENAI_API_KEY
    OPENAI_API_KEY = openai_key

    return {chatbot_col: gr.Column(visible=True)}


def slow_echo(usr_message, chat_history):
    # Guard against a missing key so the user sees a readable message instead
    # of a stack trace from the OpenAI client.
    if not OPENAI_API_KEY:
        chat_history.append((usr_message, "OpenAI API key not set. Please submit your key first."))
        yield "", chat_history
        return

    chat_model = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo', openai_api_key=OPENAI_API_KEY)

    # Conversation memory; output_key='result' tells the memory which of the
    # chain's outputs to store, since the chain also returns source documents.
    memory = ConversationBufferMemory(
        return_messages=True,
        output_key='result'
    )

    answer_chain = RetrievalQA.from_chain_type(
        chat_model,
        # `k` must go inside search_kwargs; as a bare keyword argument it is
        # rejected by the retriever's constructor.
        retriever=VectorStore.as_retriever(search_type="similarity", search_kwargs={"k": 6}),
        memory=memory,
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
        return_source_documents=True
    )

    try:
        # Run the retrieval-QA chain on the user's question.
        bot_result = answer_chain({"query": usr_message})
        bot_response = bot_result['result']

        # Group the retrieved chunks' page numbers (1-based) by document title.
        source_print = {}
        for doc in bot_result['source_documents']:
            title = doc.metadata['title']
            page = str(doc.metadata['page'] + 1)
            if title in source_print:
                source_print[title] += ', ' + page
            else:
                source_print[title] = 'page: ' + page

        # Append a formatted source list to the answer.
        bot_response += '\n Source:'
        for title, pages in source_print.items():
            bot_response += '\n' + title + ': ' + pages

        chat_history.append((usr_message, bot_response))

        time.sleep(1)

        yield "", chat_history

    except openai.error.OpenAIError as e:
        # OpenAI-specific errors (auth failures, rate limits, etc.). A plain
        # `return value` inside a generator never reaches Gradio, so the error
        # is surfaced through the chat history instead.
        error_message = f"OpenAI API Error: {e}"
        print(error_message)
        chat_history.append((usr_message, error_message))
        yield "", chat_history

    except Exception as e:
        # Any other unexpected error.
        error_message = f"Unexpected error: {e}"
        print(error_message)
        chat_history.append((usr_message, error_message))
        yield "", chat_history
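
# Quick smoke test of the pipeline without the Gradio UI. A sketch assuming a
# valid key has been set; the question string is only an example:
#
#     OPENAI_API_KEY = "sk-..."
#     for _, history in slow_echo("What topics does the latest report cover?", []):
#         print(history[-1][1])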


with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Chartis Chatbot Demo
    Please provide your own OpenAI API key below and press Submit to start chatting!
    """)

    openai_gpt_key = gr.Textbox(label="OpenAI Key", value="", type="password", placeholder="sk-")
    btn = gr.Button(value="Submit")

    with gr.Column(visible=False) as chatbot_col:
        chatbot = gr.Chatbot()
        msg = gr.Textbox(label='Type in your questions about Chartis here and press Enter!',
                        placeholder='Type in your questions.', scale=7)
        clear = gr.ClearButton([msg, chatbot])
        
        msg.submit(slow_echo, [msg, chatbot], [msg, chatbot])

    btn.click(get_openai_key, inputs=[openai_gpt_key], outputs=[chatbot_col])



demo.queue().launch(debug=True)