yahanyang777 committed on
Commit
830002d
·
1 Parent(s): 88f989f

passed the test on colab

Browse files
Files changed (1) hide show
  1. app.py +54 -53
app.py CHANGED
@@ -11,61 +11,64 @@ from langchain.memory import ConversationBufferMemory
11
  from langchain.chains import RetrievalQA, LLMChain
12
  from langchain.prompts import PromptTemplate
13
 
14
- OPENAI_API_KEY = ''
15
 
16
  embedding_file = "all_faiss_store_openai.pkl"
17
  with open(embedding_file, 'rb') as f:
18
  VectorStore = pickle.load(f)
19
 
 
20
  """ initialize all the tools """
21
 
22
  template = """
23
  You are a knowledgeable assistant of Chartis' report and you are cautious about the answer you are giving. You will refuse to answer any questions that may generate an answer that violates the Open AI policy, or is not related to the given documents.
24
  Given the user input question: {question}
25
-
26
  • If the question can be inferred from the provided context, use the context to formulate your answer.
27
  • If the question cannot be answered based on the context, simply state that you don't know. Do not provide inaccurate or made-up information.
28
-
29
  Your answers should be:
30
  • Direct and succinct.
31
  • Accurate and directly addressing the user's questions.
32
  {context}
33
-
34
  Helpful Answer:"""
35
 
36
 
37
  QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template,)
38
 
39
- chat_model = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo')
40
-
41
- # customized memory
42
- memory = ConversationBufferMemory(
43
- return_messages=True,
44
- output_key='result'
45
- )
46
-
47
- answer_chain = RetrievalQA.from_chain_type(
48
- chat_model,
49
- retriever=VectorStore.as_retriever(),
50
- memory = memory,
51
- chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
52
- return_source_documents=True
53
- )
54
-
55
-
56
 
57
  def slow_echo(usr_message, chat_history):
 
 
 
 
 
 
 
 
58
  try:
59
- # Attempt to get a response from the OpenAI model
60
- bot_result = answer_chain({"query": usr_message})
61
- bot_response = bot_result['result']
62
- source_page = bot_result['source_documents'][0].metadata['page']+1
 
 
 
63
 
64
- bot_response = bot_response +'source page:'+ str(source_page)
 
 
 
 
 
 
65
 
66
- time.sleep(1)
 
 
67
 
68
- yield bot_response
 
 
 
69
 
70
  except openai.error.OpenAIError as e:
71
  # Handle OpenAI-specific errors
@@ -77,35 +80,33 @@ def slow_echo(usr_message, chat_history):
77
  error_message = f"Unexpected error: {e}"
78
  print(error_message)
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- def get_key(api_key):
82
- if api_key!='':
83
- return "Key accepted! Displaying chatbot..."
84
- else:
85
- return "Invalid key. Please try again."
86
 
87
- key_interface = gr.Interface(
88
- fn=get_key,
89
- inputs=gr.Textbox(label="Enter your OpenAI key:"),
90
- outputs="text",
91
- live=True # This will call the function every time the input changes.
92
- )
93
 
 
94
 
95
- demo = gr.ChatInterface(
96
- slow_echo,
97
- chatbot=gr.Chatbot(height=500),
98
- textbox=gr.Textbox(label='Type in your questions about Chartis report here and press Enter!',
99
- placeholder='Type in your questions.', container=False, scale=7),
100
- title="Chartis Chatbot Demo",
101
- retry_btn=None,
102
- undo_btn=None,
103
- clear_btn="Clear"
104
- ).queue()
105
 
106
 
107
- if __name__ == "__main__":
108
- returned_data = key_interface.launch(return_data=True)
109
- OPENAI_API_KEY = returned_data[0]['data']['textbox']
110
 
111
- demo.launch(debug=True)
 
11
  from langchain.chains import RetrievalQA, LLMChain
12
  from langchain.prompts import PromptTemplate
13
 
 
14
 
15
# Load the pre-built FAISS vector store (OpenAI embeddings) from disk.
# NOTE(review): pickle.load can execute arbitrary code if the file is
# untrusted — this artifact must ship with the app, never be user-supplied.
embedding_file = "all_faiss_store_openai.pkl"
with open(embedding_file, 'rb') as store_fh:
    VectorStore = pickle.load(store_fh)
18
 
19
+
20
  """ initialize all the tools """
21
 
22
  template = """
23
  You are a knowledgeable assistant of Chartis' report and you are cautious about the answer you are giving. You will refuse to answer any questions that may generate an answer that violates the Open AI policy, or is not related to the given documents.
24
  Given the user input question: {question}
 
25
  • If the question can be inferred from the provided context, use the context to formulate your answer.
26
  • If the question cannot be answered based on the context, simply state that you don't know. Do not provide inaccurate or made-up information.
 
27
  Your answers should be:
28
  • Direct and succinct.
29
  • Accurate and directly addressing the user's questions.
30
  {context}
 
31
  Helpful Answer:"""
32
 
33
 
34
  QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template,)
35
 
36
+ OPENAI_API_KEY = ''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def slow_echo(usr_message, chat_history):
39
+ global OPENAI_API_KEY
40
+
41
+ # Check if the API key is set
42
+ if not OPENAI_API_KEY:
43
+ error_message = "OpenAI API key not set. Please provide the key first."
44
+ print(error_message)
45
+ return error_message, chat_history
46
+
47
  try:
48
+ chat_model = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo', openai_api_key=OPENAI_API_KEY)
49
+
50
+ # customized memory
51
+ memory = ConversationBufferMemory(
52
+ return_messages=True,
53
+ output_key='result'
54
+ )
55
 
56
+ answer_chain = RetrievalQA.from_chain_type(
57
+ chat_model,
58
+ retriever=VectorStore.as_retriever(),
59
+ memory = memory,
60
+ chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
61
+ return_source_documents=True
62
+ )
63
 
64
+ # Get a response from the OpenAI model
65
+ bot_result = answer_chain({"query": usr_message})
66
+ bot_response = bot_result['result']
67
 
68
+ chat_history.append((usr_message, bot_response))
69
+ time.sleep(1)
70
+
71
+ yield "", chat_history
72
 
73
  except openai.error.OpenAIError as e:
74
  # Handle OpenAI-specific errors
 
80
  error_message = f"Unexpected error: {e}"
81
  print(error_message)
82
 
83
def get_opeanai_key(openai_key):
    """Store the user-supplied OpenAI API key and reveal the chatbot UI.

    NOTE(review): the function name contains a typo ("opeanai"); it is kept
    unchanged because the existing ``btn.click`` wiring references it.

    Args:
        openai_key: API key text entered in the Gradio key textbox.

    Returns:
        A Gradio update dict that makes the ``chatbot_col`` column visible.
    """
    global OPENAI_API_KEY
    # Strip stray whitespace/newlines from copy-paste so auth doesn't fail
    # on an otherwise valid key.
    OPENAI_API_KEY = openai_key.strip()

    return {chatbot_col: gr.Column(visible=True)}
88
+
89
+
90
# Build the app UI: a key-entry screen that, once submitted, reveals the
# chat column wired to the retrieval QA generator.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chartis Chatbot Demo
        Please provide your own GPT key below first!
        """)

    # Key entry row — masked textbox plus a submit button.
    key_box = gr.Textbox(label="OpenAI Key", value="", type="password", placeholder="sk-")
    submit_btn = gr.Button(value="Submit")

    # Chat UI starts hidden; get_opeanai_key flips it visible after submit.
    with gr.Column(visible=False) as chatbot_col:
        chat_window = gr.Chatbot()
        question_box = gr.Textbox(label='Type in your questions about Chartis here and press Enter!',
                                  placeholder='Type in your questions.', scale=7)
        clear_btn = gr.ClearButton([question_box, chat_window])

        # Enter in the question box streams the answer into the chat window.
        question_box.submit(slow_echo, [question_box, chat_window], [question_box, chat_window])

    # Submitting the key stores it globally and unhides the chat column.
    submit_btn.click(get_opeanai_key, inputs=[key_box], outputs=[chatbot_col])


demo.queue().launch(debug=True)