Adrian Cowham committed on
Commit
0e2eb99
·
1 Parent(s): 06150c8

modified to accept chat history from client

Browse files
Files changed (1) hide show
  1. src/app.py +45 -2
src/app.py CHANGED
@@ -1,4 +1,4 @@
1
-
2
  import os
3
  from threading import Lock
4
  from typing import Any, Dict, Optional, Tuple
@@ -26,8 +26,10 @@ system_template = """
26
  The context below contains excerpts from 'Let's Talk,' by Andrea A. Lunsford. You must only use the information in the context below to formulate your response. If there is not enough information to formulate a response, you must respond with
27
  "I'm sorry, but I can't find the answer to your question in, the book Let's Talk..."
28
 
29
- Here is the context:
30
  {context}
 
 
31
  {chat_history}
32
  """
33
 
@@ -60,6 +62,44 @@ def getretriever():
60
 
61
  retriever = getretriever()
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def getanswer(chain, question, history):
64
  if hasattr(chain, "value"):
65
  chain = chain.value
@@ -135,4 +175,7 @@ with gr.Blocks() as block:
135
  ex5 = gr.Button(value="How do I cite a Reddit thread?", variant="primary")
136
  ex5.click(getanswer, inputs=[chain_state, ex5, state], outputs=[chatbot, state, message])
137
 
 
 
 
138
  block.launch(debug=True)
 
1
+ import json
2
  import os
3
  from threading import Lock
4
  from typing import Any, Dict, Optional, Tuple
 
26
  The context below contains excerpts from 'Let's Talk,' by Andrea A. Lunsford. You must only use the information in the context below to formulate your response. If there is not enough information to formulate a response, you must respond with
27
  "I'm sorry, but I can't find the answer to your question in, the book Let's Talk..."
28
 
29
+ Begin context:
30
  {context}
31
+ End context.
32
+
33
  {chat_history}
34
  """
35
 
 
62
 
63
  retriever = getretriever()
64
 
65
+ def predict(message):
66
+ print(message)
67
+ msgJson = json.loads(message)
68
+ print(msgJson)
69
+ messages = [
70
+ SystemMessagePromptTemplate.from_template(system_template),
71
+ HumanMessagePromptTemplate.from_template("{question}")
72
+ ]
73
+ qa_prompt = ChatPromptTemplate.from_messages(messages)
74
+
75
+ llm = ChatOpenAI(
76
+ openai_api_key=API_KEY,
77
+ model_name=MODEL,
78
+ verbose=True)
79
+ memory = AnswerConversationBufferMemory(memory_key="chat_history", return_messages=True)
80
+ for msg in msgJson["history"]:
81
+ memory.save_context({"input": msg[0]}, {"answer": msg[1]})
82
+
83
+ chain = ConversationalRetrievalChain.from_llm(
84
+ llm,
85
+ retriever=retriever,
86
+ return_source_documents=USE_VERBOSE,
87
+ memory=memory,
88
+ verbose=USE_VERBOSE,
89
+ combine_docs_chain_kwargs={"prompt": qa_prompt})
90
+ chain.rephrase_question = False
91
+ lock = Lock()
92
+ lock.acquire()
93
+ try:
94
+ output = chain({"question": msgJson["question"]})
95
+ output = output["answer"]
96
+ except Exception as e:
97
+ print(e)
98
+ raise e
99
+ finally:
100
+ lock.release()
101
+ return output
102
+
103
  def getanswer(chain, question, history):
104
  if hasattr(chain, "value"):
105
  chain = chain.value
 
175
  ex5 = gr.Button(value="How do I cite a Reddit thread?", variant="primary")
176
  ex5.click(getanswer, inputs=[chain_state, ex5, state], outputs=[chatbot, state, message])
177
 
178
+ predictBtn = gr.Button(value="Predict", visible=False)
179
+ predictBtn.click(predict, inputs=[message], outputs=[message])
180
+
181
  block.launch(debug=True)