Sean-Case
committed on
Commit
•
6a76923
1
Parent(s):
d53332d
Added temperature slider, more stringent checks for document relevance
Browse files
- app.py +3 -2
- chatfuncs/chatfuncs.py +35 -57
app.py
CHANGED
@@ -237,6 +237,7 @@ with block:
|
|
237 |
|
238 |
with gr.Tab("Advanced features"):
|
239 |
out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
|
|
|
240 |
with gr.Row():
|
241 |
model_choice = gr.Radio(label="Choose a chat model", value="Flan Alpaca (small, fast)", choices = ["Flan Alpaca (small, fast)", "Mistral Open Orca (larger, slow)"])
|
242 |
change_model_button = gr.Button(value="Load model", scale=0)
|
@@ -281,14 +282,14 @@ with block:
|
|
281 |
# Click/enter to send message action
|
282 |
response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False, api_name="retrieval").\
|
283 |
then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
|
284 |
-
then(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state], outputs=chatbot)
|
285 |
response_click.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
|
286 |
then(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
|
287 |
then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
|
288 |
|
289 |
response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False).\
|
290 |
then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
|
291 |
-
then(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state], chatbot)
|
292 |
response_enter.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
|
293 |
then(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
|
294 |
then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
|
|
|
237 |
|
238 |
with gr.Tab("Advanced features"):
|
239 |
out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
|
240 |
+
temp_slide = gr.Slider(minimum=0.1, value = 0.1, maximum=1, step=0.1, label="Choose temperature setting for response generation.")
|
241 |
with gr.Row():
|
242 |
model_choice = gr.Radio(label="Choose a chat model", value="Flan Alpaca (small, fast)", choices = ["Flan Alpaca (small, fast)", "Mistral Open Orca (larger, slow)"])
|
243 |
change_model_button = gr.Button(value="Load model", scale=0)
|
|
|
282 |
# Click/enter to send message action
|
283 |
response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False, api_name="retrieval").\
|
284 |
then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
|
285 |
+
then(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide], outputs=chatbot)
|
286 |
response_click.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
|
287 |
then(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
|
288 |
then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
|
289 |
|
290 |
response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False).\
|
291 |
then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
|
292 |
+
then(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide], chatbot)
|
293 |
response_enter.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
|
294 |
then(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
|
295 |
then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
|
chatfuncs/chatfuncs.py
CHANGED
@@ -158,6 +158,9 @@ class CtransGenGenerationConfig:
|
|
158 |
self.batch_size = batch_size
|
159 |
self.reset = reset
|
160 |
|
|
|
|
|
|
|
161 |
# Vectorstore funcs
|
162 |
|
163 |
def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
|
@@ -220,23 +223,6 @@ QUESTION: {question}
|
|
220 |
|
221 |
Response:"""
|
222 |
|
223 |
-
instruction_prompt_template_openllama = """Answer the QUESTION using information from the following CONTENT.
|
224 |
-
QUESTION - {question}
|
225 |
-
CONTENT - {summaries}
|
226 |
-
Answer:"""
|
227 |
-
|
228 |
-
instruction_prompt_template_platypus = """### Instruction:
|
229 |
-
Answer the QUESTION using information from the following CONTENT.
|
230 |
-
CONTENT: {summaries}
|
231 |
-
QUESTION: {question}
|
232 |
-
### Response:"""
|
233 |
-
|
234 |
-
instruction_prompt_template_wizard_orca_quote = """### HUMAN:
|
235 |
-
Quote text from the CONTENT to answer the QUESTION below.
|
236 |
-
CONTENT - {summaries}
|
237 |
-
QUESTION - {question}
|
238 |
-
### RESPONSE:
|
239 |
-
"""
|
240 |
|
241 |
instruction_prompt_template_wizard_orca = """### HUMAN:
|
242 |
Answer the QUESTION below based on the CONTENT. Only refer to CONTENT that directly answers the question.
|
@@ -266,15 +252,6 @@ CONTENT: {summaries}
|
|
266 |
### Response:
|
267 |
"""
|
268 |
|
269 |
-
instruction_prompt_template_orca_rev = """
|
270 |
-
### System:
|
271 |
-
You are an AI assistant that follows instruction extremely well. Help as much as you can.
|
272 |
-
### User:
|
273 |
-
Answer the QUESTION with a short response using information from the following CONTENT.
|
274 |
-
QUESTION: {question}
|
275 |
-
CONTENT: {summaries}
|
276 |
-
|
277 |
-
### Response:"""
|
278 |
|
279 |
instruction_prompt_mistral_orca = """<|im_start|>system\n
|
280 |
You are an AI assistant that follows instruction extremely well. Help as much as you can.
|
@@ -284,23 +261,6 @@ CONTENT: {summaries}
|
|
284 |
QUESTION: {question}\n
|
285 |
Answer:<|im_end|>"""
|
286 |
|
287 |
-
instruction_prompt_tinyllama_orca = """<|im_start|>system\n
|
288 |
-
You are an AI assistant that follows instruction extremely well. Help as much as you can.
|
289 |
-
<|im_start|>user\n
|
290 |
-
Answer the QUESTION using information from the following CONTENT. Only quote text that directly answers the question and nothing more. If you can't find an answer to the question, respond with "Sorry, I can't find an answer to that question.".
|
291 |
-
CONTENT: {summaries}
|
292 |
-
QUESTION: {question}\n
|
293 |
-
Answer:<|im_end|>"""
|
294 |
-
|
295 |
-
instruction_prompt_marx = """
|
296 |
-
### HUMAN:
|
297 |
-
Answer the QUESTION using information from the following CONTENT.
|
298 |
-
CONTENT: {summaries}
|
299 |
-
QUESTION: {question}
|
300 |
-
|
301 |
-
### RESPONSE:
|
302 |
-
"""
|
303 |
-
|
304 |
if model_type == "Flan Alpaca (small, fast)":
|
305 |
INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_template_alpaca, input_variables=['question', 'summaries'])
|
306 |
elif model_type == "Mistral Open Orca (larger, slow)":
|
@@ -322,9 +282,16 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
|
|
322 |
|
323 |
|
324 |
docs_keep_as_doc, doc_df, docs_keep_out = hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val = 25, out_passages = out_passages,
|
325 |
-
vec_score_cut_off =
|
326 |
#vectorstore=globals()["vectorstore"], embeddings=globals()["embeddings"])
|
327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
# Expand the found passages to the neighbouring context
|
329 |
file_type = determine_file_type(doc_df['meta_url'][0])
|
330 |
|
@@ -332,8 +299,6 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
|
|
332 |
if (file_type != ".csv") & (file_type != ".xlsx"):
|
333 |
docs_keep_as_doc, doc_df = get_expanded_passages(vectorstore, docs_keep_out, width=3)
|
334 |
|
335 |
-
if docs_keep_as_doc == []:
|
336 |
-
{"answer": "I'm sorry, I couldn't find a relevant answer to this question.", "sources":"I'm sorry, I couldn't find a relevant source for this question."}
|
337 |
|
338 |
|
339 |
# Build up sources content to add to user display
|
@@ -380,11 +345,21 @@ def create_full_prompt(user_input, history, extracted_memory, vectorstore, embed
|
|
380 |
|
381 |
print("Output history is:")
|
382 |
print(history)
|
|
|
|
|
|
|
383 |
|
384 |
return history, docs_content_string, instruction_prompt_out
|
385 |
|
386 |
# Chat functions
|
387 |
-
def produce_streaming_answer_chatbot(history, full_prompt, model_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
#print("Model type is: ", model_type)
|
389 |
|
390 |
#if not full_prompt.strip():
|
@@ -410,6 +385,9 @@ def produce_streaming_answer_chatbot(history, full_prompt, model_type):
|
|
410 |
temperature=temperature,
|
411 |
top_k=top_k
|
412 |
)
|
|
|
|
|
|
|
413 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
414 |
t.start()
|
415 |
|
@@ -437,6 +415,7 @@ def produce_streaming_answer_chatbot(history, full_prompt, model_type):
|
|
437 |
tokens = model.tokenize(full_prompt)
|
438 |
|
439 |
gen_config = CtransGenGenerationConfig()
|
|
|
440 |
|
441 |
print(vars(gen_config))
|
442 |
|
@@ -502,6 +481,8 @@ def create_doc_df(docs_keep_out):
|
|
502 |
page_section=[]
|
503 |
score=[]
|
504 |
|
|
|
|
|
505 |
|
506 |
|
507 |
for item in docs_keep_out:
|
@@ -530,6 +511,7 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
|
|
530 |
|
531 |
#vectorstore=globals()["vectorstore"]
|
532 |
#embeddings=globals()["embeddings"]
|
|
|
533 |
|
534 |
|
535 |
docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
|
@@ -545,21 +527,15 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
|
|
545 |
score_more_limit = pd.Series(docs_scores) < vec_score_cut_off
|
546 |
docs_keep = list(compress(docs, score_more_limit))
|
547 |
|
548 |
-
if docs_keep
|
549 |
-
|
550 |
-
docs_content = []
|
551 |
-
docs_url = []
|
552 |
-
return docs_keep_as_doc, docs_content, docs_url
|
553 |
|
554 |
# Only keep sources that are at least 100 characters long
|
555 |
length_more_limit = pd.Series(docs_len) >= 100
|
556 |
docs_keep = list(compress(docs_keep, length_more_limit))
|
557 |
|
558 |
-
if docs_keep
|
559 |
-
|
560 |
-
docs_content = []
|
561 |
-
docs_url = []
|
562 |
-
return docs_keep_as_doc, docs_content, docs_url
|
563 |
|
564 |
docs_keep_as_doc = [x[0] for x in docs_keep]
|
565 |
docs_keep_length = len(docs_keep_as_doc)
|
@@ -763,6 +739,8 @@ def get_expanded_passages(vectorstore, docs, width):
|
|
763 |
expanded_doc = (Document(page_content=content_str[0], metadata=meta_full[0]), score)
|
764 |
expanded_docs.append(expanded_doc)
|
765 |
|
|
|
|
|
766 |
doc_df = create_doc_df(expanded_docs) # Assuming you've defined the 'create_doc_df' function elsewhere
|
767 |
|
768 |
return expanded_docs, doc_df
|
|
|
158 |
self.batch_size = batch_size
|
159 |
self.reset = reset
|
160 |
|
161 |
+
def update_temp(self, new_value):
|
162 |
+
self.temperature = new_value
|
163 |
+
|
164 |
# Vectorstore funcs
|
165 |
|
166 |
def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
|
|
|
223 |
|
224 |
Response:"""
|
225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
|
227 |
instruction_prompt_template_wizard_orca = """### HUMAN:
|
228 |
Answer the QUESTION below based on the CONTENT. Only refer to CONTENT that directly answers the question.
|
|
|
252 |
### Response:
|
253 |
"""
|
254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
instruction_prompt_mistral_orca = """<|im_start|>system\n
|
257 |
You are an AI assistant that follows instruction extremely well. Help as much as you can.
|
|
|
261 |
QUESTION: {question}\n
|
262 |
Answer:<|im_end|>"""
|
263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
if model_type == "Flan Alpaca (small, fast)":
|
265 |
INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_template_alpaca, input_variables=['question', 'summaries'])
|
266 |
elif model_type == "Mistral Open Orca (larger, slow)":
|
|
|
282 |
|
283 |
|
284 |
docs_keep_as_doc, doc_df, docs_keep_out = hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val = 25, out_passages = out_passages,
|
285 |
+
vec_score_cut_off = 0.85, vec_weight = 1, bm25_weight = 1, svm_weight = 1)#,
|
286 |
#vectorstore=globals()["vectorstore"], embeddings=globals()["embeddings"])
|
287 |
|
288 |
+
#print(docs_keep_as_doc)
|
289 |
+
#print(doc_df)
|
290 |
+
if (not docs_keep_as_doc) | (doc_df.empty):
|
291 |
+
sorry_prompt = """Say 'Sorry, there is no relevant information to answer this question.'.
|
292 |
+
RESPONSE:"""
|
293 |
+
return sorry_prompt, "No relevant sources found.", new_question_kworded
|
294 |
+
|
295 |
# Expand the found passages to the neighbouring context
|
296 |
file_type = determine_file_type(doc_df['meta_url'][0])
|
297 |
|
|
|
299 |
if (file_type != ".csv") & (file_type != ".xlsx"):
|
300 |
docs_keep_as_doc, doc_df = get_expanded_passages(vectorstore, docs_keep_out, width=3)
|
301 |
|
|
|
|
|
302 |
|
303 |
|
304 |
# Build up sources content to add to user display
|
|
|
345 |
|
346 |
print("Output history is:")
|
347 |
print(history)
|
348 |
+
|
349 |
+
print("Final prompt to model is:")
|
350 |
+
print(instruction_prompt_out)
|
351 |
|
352 |
return history, docs_content_string, instruction_prompt_out
|
353 |
|
354 |
# Chat functions
|
355 |
+
def produce_streaming_answer_chatbot(history, full_prompt, model_type,
|
356 |
+
temperature=temperature,
|
357 |
+
max_new_tokens=max_new_tokens,
|
358 |
+
sample=sample,
|
359 |
+
repetition_penalty=repetition_penalty,
|
360 |
+
top_p=top_p,
|
361 |
+
top_k=top_k
|
362 |
+
):
|
363 |
#print("Model type is: ", model_type)
|
364 |
|
365 |
#if not full_prompt.strip():
|
|
|
385 |
temperature=temperature,
|
386 |
top_k=top_k
|
387 |
)
|
388 |
+
|
389 |
+
print(generate_kwargs)
|
390 |
+
|
391 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
392 |
t.start()
|
393 |
|
|
|
415 |
tokens = model.tokenize(full_prompt)
|
416 |
|
417 |
gen_config = CtransGenGenerationConfig()
|
418 |
+
gen_config.update_temp(temperature)
|
419 |
|
420 |
print(vars(gen_config))
|
421 |
|
|
|
481 |
page_section=[]
|
482 |
score=[]
|
483 |
|
484 |
+
doc_df = pd.DataFrame()
|
485 |
+
|
486 |
|
487 |
|
488 |
for item in docs_keep_out:
|
|
|
511 |
|
512 |
#vectorstore=globals()["vectorstore"]
|
513 |
#embeddings=globals()["embeddings"]
|
514 |
+
doc_df = pd.DataFrame()
|
515 |
|
516 |
|
517 |
docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
|
|
|
527 |
score_more_limit = pd.Series(docs_scores) < vec_score_cut_off
|
528 |
docs_keep = list(compress(docs, score_more_limit))
|
529 |
|
530 |
+
if not docs_keep:
|
531 |
+
return [], pd.DataFrame(), []
|
|
|
|
|
|
|
532 |
|
533 |
# Only keep sources that are at least 100 characters long
|
534 |
length_more_limit = pd.Series(docs_len) >= 100
|
535 |
docs_keep = list(compress(docs_keep, length_more_limit))
|
536 |
|
537 |
+
if not docs_keep:
|
538 |
+
return [], pd.DataFrame(), []
|
|
|
|
|
|
|
539 |
|
540 |
docs_keep_as_doc = [x[0] for x in docs_keep]
|
541 |
docs_keep_length = len(docs_keep_as_doc)
|
|
|
739 |
expanded_doc = (Document(page_content=content_str[0], metadata=meta_full[0]), score)
|
740 |
expanded_docs.append(expanded_doc)
|
741 |
|
742 |
+
doc_df = pd.DataFrame()
|
743 |
+
|
744 |
doc_df = create_doc_df(expanded_docs) # Assuming you've defined the 'create_doc_df' function elsewhere
|
745 |
|
746 |
return expanded_docs, doc_df
|