Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import time | |
import openai | |
# from langchain.llms import OpenAI | |
import pickle | |
from huggingface_hub import hf_hub_download | |
from langchain.chat_models import ChatOpenAI | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains import RetrievalQA, LLMChain | |
from langchain.prompts import PromptTemplate | |
# Vector store queried by the retriever in slow_echo; presumably a FAISS index
# built with OpenAI embeddings, judging by the file name — confirm.
# NOTE(review): unpickling can execute arbitrary code; only load this file
# from a trusted source (it is expected to ship alongside the app).
embedding_file = "subtitle_year_faiss_openai.pkl"
with open(embedding_file, mode="rb") as pkl_handle:
    VectorStore = pickle.loads(pkl_handle.read())

# --- initialize all the tools ---
# Prompt handed to the RetrievalQA chain (via chain_type_kwargs in slow_echo).
# Its two placeholders are {question} (the user's query) and {context}
# (presumably the text of the retrieved documents, filled in by the chain).
template = """
You are a knowledgeable assistant of Chartis' report and you are cautious about the answer you are giving. You will refuse to answer any questions that may generate an answer that violates the Open AI policy, or is not related to the given documents.
Given the user input question: {question}
• If the question can be inferred from the provided context, use the context to formulate your answer.
• If the question cannot be answered based on the context, simply state that you don't know. Do not provide inaccurate or made-up information.
Your answers should be:
• Direct and succinct.
• Accurate and directly addressing the user's questions.
{context}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate(
    template=template,
    input_variables=["context", "question"],
)
# Key used by slow_echo; '' means "not provided yet".
OPENAI_API_KEY = ''


def get_opeanai_key(openai_key):
    """Store the user's OpenAI key and reveal the chat column.

    Bound to the "Submit" button; returns an update dict that makes the
    ``chatbot_col`` column (defined in the ``gr.Blocks`` layout) visible.

    Args:
        openai_key: Raw contents of the "OpenAI Key" textbox.

    NOTE(review): the key lives in a module-level global, so on a shared
    deployment every visitor ends up using the most recently submitted key.
    A per-session ``gr.State`` would avoid that but requires rewiring the
    event handlers.
    """
    global OPENAI_API_KEY
    # Pasted keys often carry a trailing newline/space, which the API rejects;
    # a whitespace-only entry becomes '' so slow_echo reports it as empty.
    OPENAI_API_KEY = (openai_key or '').strip()
    return {chatbot_col: gr.Column(visible=True)}
def _format_sources(source_documents):
    """Render the "Source:" footer appended to every answer.

    Documents are grouped by ``metadata['title']`` in first-seen order, and
    each title lists its pages as ``metadata['page'] + 1`` (the stored page
    value is presumably 0-based -- confirm against the indexing script).
    A page that appears in several retrieved chunks is listed once.

    The footer starts with a newline and " Source:", followed by one line
    per title of the form "<title>: page: 3, 5". With no documents, only
    the header is returned.
    """
    pages_by_title = {}
    for doc in source_documents:
        title = doc.metadata['title']
        page = str(doc.metadata['page'] + 1)
        pages = pages_by_title.setdefault(title, [])
        # Several chunks can come from the same page; show it only once.
        if page not in pages:
            pages.append(page)
    footer = ['\n Source:']
    footer.extend(
        f"{title}: page: {', '.join(pages)}" for title, pages in pages_by_title.items()
    )
    return '\n'.join(footer)


def slow_echo(usr_message, chat_history):
    """Answer ``usr_message`` from the indexed documents and update the chat.

    Gradio handler bound to ``msg.submit`` with outputs ``[msg, chatbot]``.
    It is a generator, so every outcome -- including errors -- must be
    delivered with ``yield`` as a ``(textbox_value, chat_history)`` pair
    (a ``return value`` inside a generator never reaches Gradio):

    * success: yields ``("", chat_history)`` after appending
      ``(usr_message, answer + source footer)``, which clears the input box;
    * missing key or failure: yields ``(error_message, chat_history)`` with
      the history unchanged, so the message is shown in the input box.

    Args:
        usr_message: Question typed by the user.
        chat_history: List of ``(user, bot)`` pairs from the ``gr.Chatbot``;
            appended to in place on success.
    """
    if OPENAI_API_KEY == '':
        yield 'Invalid or empty OPENAI_API_KEY', chat_history
        return

    chat_model = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo', openai_api_key=OPENAI_API_KEY)
    # NOTE(review): a fresh memory is created on every call, so nothing is
    # carried between turns; the prompt also has no history placeholder.
    memory = ConversationBufferMemory(
        return_messages=True,
        output_key='result',
    )
    answer_chain = RetrievalQA.from_chain_type(
        chat_model,
        retriever=VectorStore.as_retriever(search_type="similarity"),
        memory=memory,
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
        return_source_documents=True,
    )

    try:
        bot_result = answer_chain({"query": usr_message})
        # Iterate the retrieved documents themselves (the old code used the
        # number of keys in bot_result, which is unrelated to the doc count).
        bot_response = bot_result['result'] + _format_sources(bot_result['source_documents'])
    except openai.OpenAIError as e:
        # openai.OpenAIError is exported by both the 0.x and 1.x SDKs;
        # openai.error no longer exists in 1.x.
        error_message = f"OpenAI API Error: {e}"
        print(error_message)
        yield error_message, chat_history
        return
    except Exception as e:
        # Catch-all boundary for the UI handler: report instead of crashing.
        error_message = f"Unexpected error: {e}"
        print(error_message)
        yield error_message, chat_history
        return

    chat_history.append((usr_message, bot_response))
    # Fixed 1 s pause before the reply is shown (kept from the original).
    time.sleep(1)
    yield "", chat_history
# Gradio UI: the key form sits on top and the chat column stays hidden until a
# key is submitted. Component creation order defines the on-page layout.
with gr.Blocks() as demo:
    gr.Markdown(
        "\n"
        "# Chartis Chatbot Demo\n"
        "Please provide your own GPT key below first and press submit to play with the chatbot!\n"
    )
    openai_gpt_key = gr.Textbox(
        label="OpenAI Key",
        placeholder="sk-",
        type="password",
        value="",
    )
    btn = gr.Button(value="Submit")

    with gr.Column(visible=False) as chatbot_col:
        chatbot = gr.Chatbot()
        msg = gr.Textbox(
            placeholder='Type in your questions.',
            label='Type in your questions about Chartis here and press Enter!',
            scale=7,
        )
        clear = gr.ClearButton(components=[msg, chatbot])

    # Enter in the question box streams the answer into [msg, chatbot];
    # Submit stores the key and reveals the chat column.
    msg.submit(fn=slow_echo, inputs=[msg, chatbot], outputs=[msg, chatbot])
    btn.click(fn=get_opeanai_key, inputs=[openai_gpt_key], outputs=[chatbot_col])

# queue() is required for generator handlers such as slow_echo.
demo.queue()
demo.launch(debug=True)