AISandbox / qa /qa.py
fracapuano
add: spinner for querying document, increment in number of sources referenced
07b1b19
raw
history blame
8.43 kB
import streamlit as st
from openai.error import OpenAIError
from .utils import *
from typing import Text, Union
# Forwarded to st.file_uploader(accept_multiple_files=...) in qa_main, i.e. it decides
# whether the uploader accepts several files at once.
multiple_files = True
def query_pipeline(index: VectorStore, query: Text, stream_answer: bool = False, n_sources: int = 5) -> Text:
    """Answer ``query`` using the knowledge stored in ``index``.

    Args:
        index: vector store searched for the passages relevant to ``query``.
        query: the user's question.
        stream_answer: forwarded unchanged to ``get_answer``. NOTE(review): qa_main
            iterates over the result when this is True, so the return value is
            presumably a stream of text chunks in that case rather than one string.
        n_sources: number of passages requested from ``search_docs`` (its ``k``).

    Returns:
        The ``"output_text"`` entry of the mapping produced by ``get_answer``.
    """
    relevant_passages = search_docs(index, query=query, k=n_sources)
    response = get_answer(relevant_passages, query=query, stream_answer=stream_answer)
    return response["output_text"]
def toggle_process_document():
    """Flip the ``processing_document_greenlight`` flag kept in the session state.

    A missing flag is treated as ``True`` before being flipped, so the very first call
    stores ``False``. NOTE(review): confirm this starting value is intended for the
    "greenlight to process the document" step.
    """
    flag_key = "processing_document_greenlight"
    st.session_state[flag_key] = not st.session_state.get(flag_key, True)
def register_new_file_name(file_name):
    """
    Append ``file_name`` to ``st.session_state["uploaded_file_names"]``.

    The list is created the first time a name is registered.
    """
    known_names = st.session_state.get("uploaded_file_names", [])
    known_names.append(file_name)
    st.session_state["uploaded_file_names"] = known_names
def clear_index():
    """
    Clears the index from the internal session state.

    This is a non reversible operation; it does nothing when no index is stored.
    """
    # The membership test and the deletion must target the same mapping: the index is
    # kept in st.session_state, not among this module's globals().
    if "index" in st.session_state:
        del st.session_state["index"]
def clear_session_state():
    """
    Delete every entry currently held in the session state.

    This is a non reversible operation.
    """
    # Snapshot the keys first so that deleting entries cannot disturb the iteration.
    all_keys = list(st.session_state.keys())
    for session_key in all_keys:
        del st.session_state[session_key]
def register_new_file(new_file):
    """
    Registers uploaded file(s) in the internal session state.

    Args:
        new_file: either a single file object (anything exposing ``read``, such as an
            uploaded file) or an iterable of such objects, e.g. the list returned by a
            multi-file uploader. The ``"uploaded_files"`` list is created on first use.
    """
    if "uploaded_files" not in st.session_state:
        st.session_state["uploaded_files"] = []
    if hasattr(new_file, "read"):
        # A file object is itself iterable (over its lines), so extend() would register
        # its contents line by line instead of the file.
        st.session_state["uploaded_files"].append(new_file)
    else:
        st.session_state["uploaded_files"].extend(new_file)
def clear_all_files():
    """Reset the internal session state's list of uploaded files to an empty list."""
    no_files: list = []
    st.session_state["uploaded_files"] = no_files
def append_uploaded_files(file):
    """Appends the uploaded file(s) to the internal session state.

    Args:
        file: either a single file object (anything exposing ``read``) or an iterable of
            such objects. The ``"uploaded_files"`` list is created on first use, so the
            files are kept even when nothing has been uploaded before.
    """
    # Store the list before extending it: extending the default of a .get() call would
    # mutate a throwaway list and silently drop the files.
    if "uploaded_files" not in st.session_state:
        st.session_state["uploaded_files"] = []
    if hasattr(file, "read"):
        # A file object is itself iterable (over its lines); register the file, not its lines.
        st.session_state["uploaded_files"].append(file)
    else:
        st.session_state["uploaded_files"].extend(file)
def set_openai_api_key(api_key: Text) -> bool:
    """Validate ``api_key`` and, when valid, store it in the internal session state.

    Besides the key itself, ``api_key_configured`` is set to ``True``.

    Args:
        api_key (Text): OpenAI API key

    Returns:
        bool: ``True`` whenever the function returns (an invalid key raises instead).

    Raises:
        ValueError: if ``check_openai_api_key`` rejects the key.
    """
    if check_openai_api_key(api_key=api_key):
        st.session_state["OPENAI_API_KEY"] = api_key
        st.session_state["api_key_configured"] = True
        return True
    raise ValueError("Invalid OpenAI API key! Please provide a valid key.")
def parse_file(file: Union[PDFFile, DocxFile, TxtFile, CodeFile]):
    """Converts a file to a document using the parser matching its extension.

    The extension is matched case-insensitively: ``.pdf`` goes to ``parse_pdf``,
    ``.docx`` to ``parse_docx``, and ``.txt``/``.py``/``.json``/``.html``/``.css``/``.md``
    to ``parse_txt``.

    Args:
        file: the uploaded file; only its ``name`` is inspected here.

    Returns:
        Whatever the selected parser returns, or ``None`` (after showing a Streamlit
        error) when the extension is not supported.
    """
    file_name = file.name.lower()
    if file_name.endswith(".pdf"):
        return parse_pdf(file)
    if file_name.endswith(".docx"):
        return parse_docx(file)
    # endswith() with a tuple compares the dotted suffix directly and, unlike splitting
    # on ".", also works for names containing several dots (e.g. "notes.v2.md").
    if file_name.endswith((".txt", ".py", ".json", ".html", ".css", ".md")):
        return parse_txt(file)
    st.error("File type not yet supported! Supported files: [.pdf, .docx, .txt, .py, .json, .html, .css, .md]")
    return None
# TODO: the per-file parse -> embed steps of qa_main could be factored out into a
# single-document processing pipeline, e.g.:
# def document_embedding_pipeline(file:Union[PDFFile, DocxFile, TxtFile, CodeFile]) -> None:
def _index_uploaded_files(uploaded):
    """Parse and embed the uploaded file(s); return ``{file name: index}`` for the successes.

    ``uploaded`` is the value produced by ``st.file_uploader``: a list of files when
    ``multiple_files`` is True, otherwise a single file or ``None``. Files that cannot be
    parsed or indexed are reported in the UI and skipped, so every returned name has an
    index that can be queried.
    """
    if uploaded is None:
        files = []
    elif isinstance(uploaded, list):
        files = uploaded
    else:
        files = [uploaded]

    indexes = {}
    for file in files:
        parsed_file = parse_file(file)
        if parsed_file is None:
            # unsupported file type: parse_file has already shown an error for it
            continue
        # converts the file into a list of documents
        document = text_to_docs(pages=tuple(parsed_file), file_name=file.name)
        with st.spinner(f"Indexing {file.name} (might take some time)"):
            try:
                # indexing the document uploaded
                indexes[file.name] = embed_docs(file_name=file.name, _docs=tuple(document))
            except OpenAIError as e:
                st.error(f"OpenAI error encountered: {e}")
    return indexes


def _select_chat_documents(document_names):
    """Store in ``st.session_state["multiselect_documents_choices"]`` the documents to ask.

    With several documents the user chooses through a multiselect bound to that key. With
    one (or none), the choice is set directly, so a selection left over from a previous
    run never refers to a document that is no longer indexed.
    """
    if len(document_names) > 1:
        # documents to be indexed when providing the query
        st.multiselect(
            label="Select the documents to be indexed",
            options=document_names,
            key="multiselect_documents_choices",
        )
    else:
        st.session_state["multiselect_documents_choices"] = list(document_names)


def _render_chat(indexes):
    """Replay the chat history, then ask a new prompt to every selected, indexed document."""
    import html  # escapes user-controlled text embedded in the HTML of assistant messages

    # this is the code that actually performs the chat process
    if "messages" not in st.session_state:  # checking if there is any cache history
        st.session_state["messages"] = []
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            # Only assistant messages carry HTML generated by this app; user prompts are
            # replayed as plain markdown, exactly as they are shown when first sent.
            st.markdown(message["content"], unsafe_allow_html=message["role"] == "assistant")

    prompt = st.chat_input("Ask the document something...")
    if not prompt:
        return
    if prompt == "1":
        prompt = "What is this document about?"

    # Ignore selections without an index (file failed to index, or was removed).
    chat_documents = [
        name
        for name in st.session_state.get("multiselect_documents_choices", [])
        if name in indexes
    ]
    if not chat_documents:
        st.warning("Please upload and select at least one document before asking a question.")
        return

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # full_response will store every question asked to all the document(s) considered
        full_response = ""
        message_placeholder = st.empty()
        # asking the same question to all of the documents considered
        for chat_document in chat_documents:
            # Keep track of what is asked to which document. Both values come from the
            # user, so they are escaped before being embedded in HTML.
            full_response += (
                f"<i>Asking</i> <b>{html.escape(chat_document)}</b> "
                f"<i>question</i> <b>{html.escape(prompt)}</b><br>"
            )
            message_placeholder.markdown(full_response, unsafe_allow_html=True)
            with st.spinner("Querying the document..."):
                # retrieving the vector store associated to the chat document considered
                chat_index = indexes[chat_document]
                # producing the answer considered, live
                for answer_bit in query_pipeline(chat_index, prompt, stream_answer=True, n_sources=20):
                    full_response += answer_bit
                    message_placeholder.markdown(full_response + "▌", unsafe_allow_html=True)
            # line break separating the answers given by different documents
            full_response += "<br>"
            message_placeholder.markdown(full_response, unsafe_allow_html=True)
    # appending the final response obtained after having asked all the documents
    st.session_state.messages.append({"role": "assistant", "content": full_response})


def qa_main():
    """Main function for the QA app.

    Configures the OpenAI key, lets the user upload files, indexes them, and then runs the
    chat that asks every selected document the user's question.
    """
    st.title("Chat with a file 💬📖")
    st.write("Just upload something using and start chatting with a version of GPT4 that has read the file!")
    # OpenAI API Key - TODO: consider adding a key valid for everyone
    # st.header("Configure OpenAI API Key")
    # st.warning('Please enter your OpenAI API Key!', icon='⚠️')
    # uncomment the following lines to add a user-specific key
    # user_secret = st.text_input(
    #     "Insert your OpenAI API key here ([get your API key](https://platform.openai.com/account/api-keys)).",
    #     type="password",
    #     placeholder="Paste your OpenAI API key here (sk-...)",
    #     help="You can get your API key from https://platform.openai.com/account/api-keys.",
    #     value=st.session_state.get("OPENAI_API_KEY", ""),
    # )
    user_secret = st.secrets["OPENAI_API_KEY"]
    if user_secret and set_openai_api_key(user_secret):
        # removing this when the OpenAI API key is hardcoded
        # st.success('OpenAI API key successfully accessed!', icon='✅')
        # greenlight for next step, i.e. uploading the document to chat with
        st.session_state["upload_document_greenlight"] = True

    if not st.session_state.get("upload_document_greenlight"):
        return

    # File that needs to be queried
    st.header("Upload a file")
    uploaded = st.file_uploader(
        "Upload a pdf, docx, or txt file (scanned documents not supported)",
        type=["pdf", "docx", "txt", "py", "json", "html", "css", "md"],
        help="Scanned documents are not supported yet 🥲",
        accept_multiple_files=multiple_files,
        # on_change=toggle_process_document,
        key="uploaded_file",
    )
    indexes = _index_uploaded_files(uploaded)
    _select_chat_documents(list(indexes))
    _render_chat(indexes)