Spaces:
Runtime error
Runtime error
Adrian Cowham
commited on
Commit
·
0e2eb99
1
Parent(s):
06150c8
modifed to accept chat history from client
Browse files- src/app.py +45 -2
src/app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
import os
|
3 |
from threading import Lock
|
4 |
from typing import Any, Dict, Optional, Tuple
|
@@ -26,8 +26,10 @@ system_template = """
|
|
26 |
The context below contains excerpts from 'Let's Talk,' by Andrea A. Lunsford. You must only use the information in the context below to formulate your response. If there is not enough information to formulate a response, you must respond with
|
27 |
"I'm sorry, but I can't find the answer to your question in, the book Let's Talk..."
|
28 |
|
29 |
-
|
30 |
{context}
|
|
|
|
|
31 |
{chat_history}
|
32 |
"""
|
33 |
|
@@ -60,6 +62,44 @@ def getretriever():
|
|
60 |
|
61 |
retriever = getretriever()
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
def getanswer(chain, question, history):
|
64 |
if hasattr(chain, "value"):
|
65 |
chain = chain.value
|
@@ -135,4 +175,7 @@ with gr.Blocks() as block:
|
|
135 |
ex5 = gr.Button(value="How do I cite a Reddit thread?", variant="primary")
|
136 |
ex5.click(getanswer, inputs=[chain_state, ex5, state], outputs=[chatbot, state, message])
|
137 |
|
|
|
|
|
|
|
138 |
block.launch(debug=True)
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
from threading import Lock
|
4 |
from typing import Any, Dict, Optional, Tuple
|
|
|
26 |
The context below contains excerpts from 'Let's Talk,' by Andrea A. Lunsford. You must only use the information in the context below to formulate your response. If there is not enough information to formulate a response, you must respond with
|
27 |
"I'm sorry, but I can't find the answer to your question in, the book Let's Talk..."
|
28 |
|
29 |
+
Begin context:
|
30 |
{context}
|
31 |
+
End context.
|
32 |
+
|
33 |
{chat_history}
|
34 |
"""
|
35 |
|
|
|
62 |
|
63 |
retriever = getretriever()
|
64 |
|
65 |
+
def predict(message):
|
66 |
+
print(message)
|
67 |
+
msgJson = json.loads(message)
|
68 |
+
print(msgJson)
|
69 |
+
messages = [
|
70 |
+
SystemMessagePromptTemplate.from_template(system_template),
|
71 |
+
HumanMessagePromptTemplate.from_template("{question}")
|
72 |
+
]
|
73 |
+
qa_prompt = ChatPromptTemplate.from_messages(messages)
|
74 |
+
|
75 |
+
llm = ChatOpenAI(
|
76 |
+
openai_api_key=API_KEY,
|
77 |
+
model_name=MODEL,
|
78 |
+
verbose=True)
|
79 |
+
memory = AnswerConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
80 |
+
for msg in msgJson["history"]:
|
81 |
+
memory.save_context({"input": msg[0]}, {"answer": msg[1]})
|
82 |
+
|
83 |
+
chain = ConversationalRetrievalChain.from_llm(
|
84 |
+
llm,
|
85 |
+
retriever=retriever,
|
86 |
+
return_source_documents=USE_VERBOSE,
|
87 |
+
memory=memory,
|
88 |
+
verbose=USE_VERBOSE,
|
89 |
+
combine_docs_chain_kwargs={"prompt": qa_prompt})
|
90 |
+
chain.rephrase_question = False
|
91 |
+
lock = Lock()
|
92 |
+
lock.acquire()
|
93 |
+
try:
|
94 |
+
output = chain({"question": msgJson["question"]})
|
95 |
+
output = output["answer"]
|
96 |
+
except Exception as e:
|
97 |
+
print(e)
|
98 |
+
raise e
|
99 |
+
finally:
|
100 |
+
lock.release()
|
101 |
+
return output
|
102 |
+
|
103 |
def getanswer(chain, question, history):
|
104 |
if hasattr(chain, "value"):
|
105 |
chain = chain.value
|
|
|
175 |
ex5 = gr.Button(value="How do I cite a Reddit thread?", variant="primary")
|
176 |
ex5.click(getanswer, inputs=[chain_state, ex5, state], outputs=[chatbot, state, message])
|
177 |
|
178 |
+
predictBtn = gr.Button(value="Predict", visible=False)
|
179 |
+
predictBtn.click(predict, inputs=[message], outputs=[message])
|
180 |
+
|
181 |
block.launch(debug=True)
|