|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain.chains.question_answering import load_qa_chain |
|
from langchain.chains import RetrievalQA |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.memory import ConversationTokenBufferMemory |
|
from langchain.llms import HuggingFacePipeline |
|
|
|
from langchain.prompts import PromptTemplate |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
|
from langchain.vectorstores import Chroma |
|
from chromadb.utils import embedding_functions |
|
from langchain.embeddings import HuggingFaceBgeEmbeddings |
|
from langchain.document_loaders import ( |
|
CSVLoader, |
|
DirectoryLoader, |
|
GitLoader, |
|
NotebookLoader, |
|
OnlinePDFLoader, |
|
PythonLoader, |
|
TextLoader, |
|
UnstructuredFileLoader, |
|
UnstructuredHTMLLoader, |
|
UnstructuredPDFLoader, |
|
UnstructuredWordDocumentLoader, |
|
WebBaseLoader, |
|
PyPDFLoader, |
|
UnstructuredMarkdownLoader, |
|
UnstructuredEPubLoader, |
|
UnstructuredHTMLLoader, |
|
UnstructuredPowerPointLoader, |
|
UnstructuredODTLoader, |
|
NotebookLoader, |
|
UnstructuredFileLoader |
|
) |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
StoppingCriteria, |
|
StoppingCriteriaList, |
|
pipeline, |
|
GenerationConfig, |
|
TextStreamer, |
|
pipeline |
|
) |
|
from langchain.llms import HuggingFaceHub |
|
import torch |
|
from transformers import BitsAndBytesConfig |
|
import os |
|
from langchain.llms import CTransformers |
|
import streamlit as st |
|
from langchain.document_loaders.base import BaseLoader |
|
from langchain.schema import Document |
|
import gradio as gr |
|
import tempfile |
|
import timeit |
|
import textwrap |
|
|
|
# Lower-case file extension (without the leading dot) -> (LangChain loader class,
# extra keyword arguments passed to that loader's constructor).
FILE_LOADER_MAPPING = dict(
    csv=(CSVLoader, dict(encoding="utf-8")),
    doc=(UnstructuredWordDocumentLoader, dict()),
    docx=(UnstructuredWordDocumentLoader, dict()),
    epub=(UnstructuredEPubLoader, dict()),
    html=(UnstructuredHTMLLoader, dict()),
    md=(UnstructuredMarkdownLoader, dict()),
    odt=(UnstructuredODTLoader, dict()),
    pdf=(PyPDFLoader, dict()),
    ppt=(UnstructuredPowerPointLoader, dict()),
    pptx=(UnstructuredPowerPointLoader, dict()),
    txt=(TextLoader, dict(encoding="utf8")),
    ipynb=(NotebookLoader, dict()),
    py=(PythonLoader, dict()),
)
|
|
|
def load_model(): |
|
config = {'max_new_tokens': 1024, |
|
'repetition_penalty': 1.1, |
|
'temperature': 0.1, |
|
'top_k': 50, |
|
'top_p': 0.9, |
|
'stream': True, |
|
'threads': int(os.cpu_count() / 2) |
|
} |
|
|
|
llm = CTransformers( |
|
model = "TheBloke/zephyr-7B-beta-GGUF", |
|
model_file = "zephyr-7b-beta.Q4_0.gguf", |
|
callbacks=[StreamingStdOutCallbackHandler()], |
|
lib="avx2", |
|
**config |
|
|
|
|
|
|
|
) |
|
return llm |
|
|
|
def create_vector_database(loaded_documents, persist_directory='db'):
    """Build a persisted Chroma vector store from already-loaded documents.

    The documents are split into chunks of at most 500 characters with a
    30-character overlap. Each chunk is embedded with the ``BAAI/bge-large-en``
    model on CPU. The chunks are then written to a Chroma collection
    persisted under *persist_directory*.

    Any collection previously persisted in that directory is dropped first,
    so the store reflects only *loaded_documents*. Otherwise every call would
    append another copy of the chunks, and retrieval would return duplicates
    and documents from earlier uploads.

    Args:
        loaded_documents: LangChain ``Document`` objects to index.
        persist_directory: Directory holding the Chroma files (default ``'db'``).

    Returns:
        Chroma: The populated vector store.

    Raises:
        ValueError: If *loaded_documents* is empty/None or yields no chunks.
    """
    if not loaded_documents:
        raise ValueError("No documents were loaded; upload at least one supported file.")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30, length_function=len)
    chunked_documents = text_splitter.split_documents(loaded_documents)
    if not chunked_documents:
        raise ValueError("The loaded documents contain no text to index.")

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )

    # Clear the collection left by a previous call so chunks are not re-added
    # on top of old ones. NOTE(review): relies on Chroma clients for the same
    # directory sharing state in-process (chromadb >= 0.4) — confirm version.
    Chroma(persist_directory=persist_directory, embedding_function=embeddings).delete_collection()

    db = Chroma.from_documents(
        documents=chunked_documents,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
    db.persist()
    return db
|
|
|
def set_custom_prompt():
    """Return the PromptTemplate used to answer questions from retrieved text.

    The template has two input variables, ``context`` and ``question``, and
    instructs the model to admit when it does not know the answer.
    """
    template_text = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
    return PromptTemplate(input_variables=["context", "question"], template=template_text)
|
|
|
def create_chain(llm, prompt, db):
    """Create a RetrievalQA chain that answers questions from *db* using *llm*.

    The chain uses the ``'stuff'`` chain type. The retriever returns the top
    3 chunks (``k=3``), which fill the prompt's ``{context}`` slot, and the
    user's query fills ``{question}``. Source documents are included in the
    output.

    Args:
        llm: The LangChain language model that generates the answer.
        prompt: A ``PromptTemplate`` with ``context`` and ``question`` input
            variables (such as the one from ``set_custom_prompt``).
        db: A vector store exposing ``as_retriever``.

    Returns:
        RetrievalQA: A chain called with ``{"query": ...}``. Its output
        contains ``"result"`` and ``"source_documents"``.
    """
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=db.as_retriever(search_kwargs={'k': 3}),
        return_source_documents=True,
        # Without this, the 'stuff' chain falls back to LangChain's default QA
        # prompt and the *prompt* argument has no effect.
        chain_type_kwargs={'prompt': prompt},
    )
    return chain
|
|
|
def create_retrieval_qa_bot(loaded_documents): |
|
|
|
|
|
|
|
try: |
|
llm = load_model() |
|
except Exception as e: |
|
raise Exception(f"Failed to load model: {str(e)}") |
|
|
|
try: |
|
prompt = set_custom_prompt() |
|
except Exception as e: |
|
raise Exception(f"Failed to get prompt: {str(e)}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
db = create_vector_database(loaded_documents) |
|
except Exception as e: |
|
raise Exception(f"Failed to get database: {str(e)}") |
|
|
|
try: |
|
|
|
|
|
|
|
qa = create_chain( |
|
llm=llm, prompt=prompt, db=db |
|
) |
|
except Exception as e: |
|
raise Exception(f"Failed to create retrieval QA chain: {str(e)}") |
|
|
|
return qa |
|
|
|
def wrap_text_preserve_newlines(text, width=110):
    """Wrap *text* to *width* columns while keeping its existing line breaks.

    Each newline-separated line is filled on its own with ``textwrap.fill``,
    and the filled pieces are joined back together with newlines.
    """
    return '\n'.join(textwrap.fill(segment, width=width) for segment in text.split('\n'))
|
|
|
def retrieve_bot_answer(query, loaded_documents): |
|
""" |
|
Retrieves the answer to a given query using a QA bot. |
|
This function creates an instance of a QA bot, passes the query to it, |
|
and returns the bot's response. |
|
Args: |
|
query (str): The question to be answered by the QA bot. |
|
Returns: |
|
dict: The QA bot's response, typically a dictionary with response details. |
|
""" |
|
qa_bot_instance = create_retrieval_qa_bot(loaded_documents) |
|
|
|
bot_response = qa_bot_instance({"query": query}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result = wrap_text_preserve_newlines(bot_response['result']) |
|
for source in bot_response["source_documents"]: |
|
sources.append(source.metadata['source']) |
|
return result, sources |
|
|
|
def _load_uploaded_documents(uploaded_files):
    """Save each upload to a temporary directory, load it with its mapped loader, and return all Documents."""
    loaded_documents = []
    with tempfile.TemporaryDirectory() as td:
        for uploaded_file in uploaded_files:
            st.write(f"Uploaded: {uploaded_file.name}")
            ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
            st.write(f"File type: {ext}")

            if ext not in FILE_LOADER_MAPPING:
                st.warning(f"Unsupported file extension: {ext}")
                continue
            loader_class, loader_args = FILE_LOADER_MAPPING[ext]

            # The file name comes from the client; keep only its last path
            # component so it cannot point outside the temporary directory.
            file_path = os.path.join(td, os.path.basename(uploaded_file.name))
            with open(file_path, 'wb') as temp_file:
                # getvalue() returns the whole upload regardless of the stream position.
                temp_file.write(uploaded_file.getvalue())

            try:
                loader = loader_class(file_path, **loader_args)
                loaded_documents.extend(loader.load())
            except Exception as e:
                # One unreadable file should not prevent the others from loading.
                st.error(f"Failed to load {uploaded_file.name}: {e}")
    return loaded_documents


def main():
    """Streamlit page: upload documents, ask a question, and show the answer with its sources."""
    st.title("Docuverse")

    # Accept exactly the extensions that have a loader configured.
    uploaded_files = st.file_uploader("Upload your documents", type=list(FILE_LOADER_MAPPING), accept_multiple_files=True)
    loaded_documents = _load_uploaded_documents(uploaded_files) if uploaded_files else []

    st.write("Chat with the Document:")
    query = st.text_input("Ask a question:")

    if not st.button("Get Answer"):
        return
    if not query:
        st.warning("Please enter a question.")
        return
    if not loaded_documents:
        st.warning("Please upload at least one supported document.")
        return

    try:
        start = timeit.default_timer()
        # retrieve_bot_answer builds the LLM, prompt and vector store itself,
        # so nothing is pre-built here.
        result, sources = retrieve_bot_answer(query, loaded_documents)
        end = timeit.default_timer()

        st.write("Elapsed time:")
        st.write(end - start)

        st.write("Bot Response:")
        st.write(result)
        st.write(sources)
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
|
|
|
if __name__ == "__main__":
    # Render the Streamlit page defined in main() when this file is executed as a script.
    main()
|
|