from langchain.chains import ConversationalRetrievalChain |
from langchain.chains.question_answering import load_qa_chain |
from langchain.chains import RetrievalQA |
from langchain.memory import ConversationBufferMemory |
from langchain.memory import ConversationTokenBufferMemory |
from langchain.llms import HuggingFacePipeline |
from langchain.prompts import PromptTemplate |
from langchain.embeddings import HuggingFaceEmbeddings |
from langchain.text_splitter import RecursiveCharacterTextSplitter |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
from langchain.vectorstores import Chroma |
from chromadb.utils import embedding_functions |
from langchain.embeddings import HuggingFaceBgeEmbeddings |
from langchain.document_loaders import ( |
CSVLoader, |
DirectoryLoader, |
GitLoader, |
NotebookLoader, |
OnlinePDFLoader, |
PythonLoader, |
TextLoader, |
UnstructuredFileLoader, |
UnstructuredHTMLLoader, |
UnstructuredPDFLoader, |
UnstructuredWordDocumentLoader, |
WebBaseLoader, |
PyPDFLoader, |
UnstructuredMarkdownLoader, |
UnstructuredEPubLoader, |
UnstructuredHTMLLoader, |
UnstructuredPowerPointLoader, |
UnstructuredODTLoader, |
NotebookLoader, |
UnstructuredFileLoader |
) |
from transformers import ( |
AutoModelForCausalLM, |
AutoTokenizer, |
StoppingCriteria, |
StoppingCriteriaList, |
pipeline, |
GenerationConfig, |
TextStreamer, |
pipeline |
) |
from langchain.llms import HuggingFaceHub |
import torch |
from transformers import BitsAndBytesConfig |
import os |
from langchain.llms import CTransformers |
import streamlit as st |
from langchain.document_loaders.base import BaseLoader |
from langchain.schema import Document |
import gradio as gr |
import tempfile |
import timeit |
import textwrap |
# Maps a lower-case file extension (without the leading dot) to a tuple of
# (loader class, keyword arguments for that loader). ``main`` looks up the
# extension of each uploaded file here and instantiates the loader as
# ``loader_class(file_path, **loader_args)``.
FILE_LOADER_MAPPING = {
    "csv": (CSVLoader, {"encoding": "utf-8"}),
    "doc": (UnstructuredWordDocumentLoader, {}),
    "docx": (UnstructuredWordDocumentLoader, {}),
    "epub": (UnstructuredEPubLoader, {}),
    "html": (UnstructuredHTMLLoader, {}),
    "md": (UnstructuredMarkdownLoader, {}),
    "odt": (UnstructuredODTLoader, {}),
    "pdf": (PyPDFLoader, {}),
    "ppt": (UnstructuredPowerPointLoader, {}),
    "pptx": (UnstructuredPowerPointLoader, {}),
    "txt": (TextLoader, {"encoding": "utf8"}),
    "ipynb": (NotebookLoader, {}),
    "py": (PythonLoader, {}),
}
def load_model():
    """Load the quantized Zephyr-7B-beta GGUF model through CTransformers.

    Generated tokens are echoed to stdout by ``StreamingStdOutCallbackHandler``.

    Returns:
        CTransformers: The configured LangChain LLM wrapper.
    """
    # os.cpu_count() may return None, and halving a single core would yield
    # zero threads, so fall back to a sane value and never go below one.
    thread_count = max(1, (os.cpu_count() or 2) // 2)
    config = {
        'max_new_tokens': 1024,
        'repetition_penalty': 1.1,
        'temperature': 0.1,
        'top_k': 50,
        'top_p': 0.9,
        'stream': True,
        'threads': thread_count,
    }
    llm = CTransformers(
        model="TheBloke/zephyr-7B-beta-GGUF",
        model_file="zephyr-7b-beta.Q4_0.gguf",
        callbacks=[StreamingStdOutCallbackHandler()],
        lib="avx2",
        # Generation settings go through the wrapper's ``config`` field, the
        # documented route to the underlying ctransformers model, instead of
        # being splatted as top-level keyword arguments.
        config=config,
    )
    return llm
def create_vector_database(loaded_documents):
    """
    Build and persist a Chroma vector store from already-loaded documents.
    The documents are split into chunks of at most 500 characters with a
    30-character overlap, embedded on the CPU with the BAAI/bge-large-en model,
    written to a Chroma database in the 'db' directory, and the database
    object is returned.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=30,
        length_function=len,
    )
    chunks = splitter.split_documents(loaded_documents)

    bge_embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )

    vector_store = Chroma.from_documents(
        documents=chunks,
        embedding=bge_embeddings,
        persist_directory='db',
    )
    vector_store.persist()
    return vector_store
def set_custom_prompt():
    """
    Build the question-answering prompt template.
    The template expects two input variables, ``context`` and ``question``.
    """
    template_text = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
    return PromptTemplate(
        input_variables=["context", "question"],
        template=template_text,
    )
def create_chain(llm, prompt, db):
    """
    Creates a Retrieval Question-Answering (QA) chain using a given language model, prompt, and database.
    This function initializes a RetrievalQA chain of type 'stuff' that answers with the
    supplied prompt and also returns the source documents it used.
    The retriever is set up to return the top 3 results (k=3).
    Args:
    llm (any): The language model to be used in the RetrievalQA.
    prompt (PromptTemplate): The prompt applied by the 'stuff' chain; it must expose the
    ``context`` and ``question`` input variables.
    db (any): The vector database to be used as the retriever.
    Returns:
    RetrievalQA: The initialized retrieval QA chain.
    """
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=db.as_retriever(search_kwargs={'k': 3}),
        return_source_documents=True,
        # Hand the custom prompt to the underlying 'stuff' chain; without this
        # the chain silently falls back to LangChain's default QA prompt.
        chain_type_kwargs={'prompt': prompt},
    )
    return chain
def create_retrieval_qa_bot(loaded_documents):
    """
    Assemble the QA chain: load the model, build the prompt, build the vector
    database from ``loaded_documents`` and wire them together.
    Any failure in a step is re-raised as an ``Exception`` whose message names
    the step that failed.
    """
    def _run_step(step, failure_label):
        # Run one setup step and prefix any error with the step's label.
        try:
            return step()
        except Exception as e:
            raise Exception(f"{failure_label}: {e!s}")

    llm = _run_step(lambda: load_model(), "Failed to load model")
    prompt = _run_step(lambda: set_custom_prompt(), "Failed to get prompt")
    db = _run_step(
        lambda: create_vector_database(loaded_documents),
        "Failed to get database",
    )
    qa = _run_step(
        lambda: create_chain(llm=llm, prompt=prompt, db=db),
        "Failed to create retrieval QA chain",
    )
    return qa
def wrap_text_preserve_newlines(text, width=110):
    """
    Wrap each newline-separated line of ``text`` to at most ``width`` columns,
    keeping the original line breaks in place.
    """
    return '\n'.join(
        textwrap.fill(segment, width=width) for segment in text.split('\n')
    )
def retrieve_bot_answer(query, loaded_documents):
    """
    Retrieves the answer to a given query using a QA bot.
    This function creates an instance of a QA bot over the given documents,
    passes the query to it, and returns the wrapped answer text together with
    the sources of the documents the answer was drawn from.
    Args:
    query (str): The question to be answered by the QA bot.
    loaded_documents (list): The documents the QA bot is built over.
    Returns:
    tuple: ``(result, sources)`` where ``result`` is the answer text wrapped
    by ``wrap_text_preserve_newlines`` and ``sources`` is a list with the
    ``'source'`` metadata entry of each returned source document.
    """
    qa_bot_instance = create_retrieval_qa_bot(loaded_documents)
    bot_response = qa_bot_instance({"query": query})
    result = wrap_text_preserve_newlines(bot_response['result'])
    # Build the list here: appending to a name that was never initialised
    # raises NameError on the first source document.
    sources = [
        source.metadata['source']
        for source in bot_response["source_documents"]
    ]
    return result, sources
def main():
    """Run the Streamlit app: upload documents, then ask questions about them.

    Each uploaded file is written to a temporary directory, read with the
    loader registered for its extension in ``FILE_LOADER_MAPPING`` and added
    to the list of loaded documents. When the user submits a question, the
    answer, the elapsed time and the source list are written to the page.
    """
    st.title("Docuverse")
    uploaded_files = st.file_uploader(
        "Upload your documents",
        type=["pdf", "md", "txt", "csv", "py", "epub", "html", "ppt", "pptx", "doc", "docx", "odt", "ipynb"],
        accept_multiple_files=True,
    )
    loaded_documents = []
    if uploaded_files:
        # The loaders are given a file path, so every upload is first written
        # into a temporary directory and loaded while that directory exists.
        with tempfile.TemporaryDirectory() as td:
            for uploaded_file in uploaded_files:
                st.write(f"Uploaded: {uploaded_file.name}")
                ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
                st.write(f"Uploaded: {ext}")
                if ext in FILE_LOADER_MAPPING:
                    loader_class, loader_args = FILE_LOADER_MAPPING[ext]
                    # Keep only the final path component so a client-supplied
                    # name cannot point outside the temporary directory.
                    file_path = os.path.join(td, os.path.basename(uploaded_file.name))
                    with open(file_path, 'wb') as temp_file:
                        temp_file.write(uploaded_file.read())
                    loader = loader_class(file_path, **loader_args)
                    loaded_documents.extend(loader.load())
                else:
                    st.warning(f"Unsupported file extension: {ext}")
    st.write("Chat with the Document:")
    query = st.text_input("Ask a question:")
    if st.button("Get Answer"):
        if not query:
            st.warning("Please enter a question.")
        elif not loaded_documents:
            # Nothing to retrieve from; skip the expensive model load.
            st.warning("Please upload at least one supported document.")
        else:
            try:
                start = timeit.default_timer()
                # retrieve_bot_answer builds the model, prompt and vector
                # database itself, so they are not built a second time here
                # (doing so would also insert every chunk into the persisted
                # store twice).
                result, sources = retrieve_bot_answer(query, loaded_documents)
                end = timeit.default_timer()
                st.write("Elapsed time:")
                st.write(end - start)
                st.write("Bot Response:")
                st.write(result)
                st.write(sources)
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()