# Source: docs-bot / pages/jarvis.py
# Last commit by Huzaifa367: "Update pages/jarvis.py" (2f49f39, verified)
# File size: 4.63 kB
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
from pathlib import Path
import chromadb
from unidecode import unidecode
from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate
import re
# Function to load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load every PDF in ``list_file_path`` and split the pages into chunks.

    Args:
        list_file_path: Iterable of PDF locations accepted by ``PyPDFLoader``.
        chunk_size: Maximum chunk size handed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks handed to the
            text splitter.

    Returns:
        The list of document chunks produced by
        ``RecursiveCharacterTextSplitter.split_documents``.
    """
    # Build all loaders up front, then pull their pages into one flat list.
    pdf_loaders = [PyPDFLoader(pdf_path) for pdf_path in list_file_path]
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(all_pages)
# Function to create vector database
def create_db(splits, collection_name):
    """Embed ``splits`` and store them in a new Chroma collection.

    Args:
        splits: Document chunks to index (as returned by ``load_doc``).
        collection_name: Name of the Chroma collection to create.

    Returns:
        The ``Chroma`` vector store built from the documents.
    """
    embedding_model = HuggingFaceEmbeddings()  # library-default embedding model
    ephemeral_client = chromadb.EphemeralClient()
    return Chroma.from_documents(
        documents=splits,
        embedding=embedding_model,
        client=ephemeral_client,
        collection_name=collection_name,
    )
# Initialize Langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build a conversational retrieval QA chain backed by ``vector_db``.

    Args:
        llm_model: Hugging Face repo id of the LLM to use. Only
            ``"mistralai/Mixtral-8x7B-Instruct-v0.1"`` is currently supported.
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: Forwarded to the endpoint as ``max_new_tokens``.
        top_k: Top-k sampling parameter forwarded to the endpoint.
        vector_db: Vector store exposing ``as_retriever()``.

    Returns:
        A ``ConversationalRetrievalChain`` with conversation memory that also
        returns its source documents.

    Raises:
        ValueError: If ``llm_model`` is not a supported model. Previously this
            case fell through and crashed with ``UnboundLocalError`` because
            ``llm`` was never assigned.
    """
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
            # NOTE(review): `load_in_8bit` is kept from the original code;
            # confirm that HuggingFaceEndpoint actually accepts/forwards it.
            load_in_8bit=True,
        )
    # Add other LLM models initialization conditions here...
    else:
        raise ValueError(f"Unsupported LLM model: {llm_model!r}")

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        # The chain returns both 'answer' and source documents; memory must
        # be told which output to store.
        output_key='answer',
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain
# Function to process uploaded PDFs and initialize the database
def _collection_name_from_path(filepath):
    """Derive a Chroma-friendly collection name from a file name or path."""
    # Transliterate to ASCII, then collapse anything non-alphanumeric to '-'.
    name = unidecode(Path(filepath).stem)
    name = re.sub(r"[^A-Za-z0-9]+", "-", name)[:50]
    # Chroma's documented naming rules require at least 3 characters and an
    # alphanumeric first/last character.
    if len(name) < 3:
        name += "xyz"
    if not name[0].isalnum():
        name = "A" + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + "Z"
    return name


def _materialize_upload(file_obj, target_dir, index):
    """Return a filesystem path for ``file_obj``, writing its bytes if needed."""
    if hasattr(file_obj, "getvalue"):
        # In-memory upload (e.g. Streamlit's UploadedFile): `.name` is only the
        # original filename, so the bytes must be written to disk before a
        # path-based loader can read them. The index prefix avoids collisions
        # between uploads that share a filename.
        target = Path(target_dir) / f"{index}_{Path(file_obj.name).name}"
        target.write_bytes(file_obj.getvalue())
        return str(target)
    # Otherwise assume `.name` already is a readable path.
    return file_obj.name


def process_documents(list_file_obj, chunk_size, chunk_overlap):
    """Turn uploaded PDF files into a vector database.

    Args:
        list_file_obj: Uploaded file objects; ``None`` entries are ignored.
            Each object must expose ``.name`` and, for in-memory uploads,
            ``.getvalue()``.
        chunk_size: Chunk size forwarded to ``load_doc``.
        chunk_overlap: Chunk overlap forwarded to ``load_doc``.

    Returns:
        The vector store returned by ``create_db``. The collection is named
        after the first uploaded file.

    Raises:
        ValueError: If no usable file object was provided.
    """
    import tempfile

    uploads = [f for f in (list_file_obj or []) if f is not None]
    if not uploads:
        raise ValueError("No PDF documents were provided.")

    collection_name = _collection_name_from_path(uploads[0].name)

    # load_doc reads every page eagerly, so the temporary copies can be
    # discarded as soon as it returns.
    with tempfile.TemporaryDirectory() as tmp_dir:
        list_file_path = [
            _materialize_upload(upload, tmp_dir, index)
            for index, upload in enumerate(uploads)
        ]
        doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)

    vector_db = create_db(doc_splits, collection_name)
    return vector_db
# Streamlit app
def main():
    """Render the Streamlit page: upload PDFs, index them, and chat.

    Streamlit re-executes this function on every widget interaction, so the
    vector database and the QA chain are kept in ``st.session_state`` rather
    than in local variables, which would be unbound on the next re-run.
    """
    st.title("PDF-based Chatbot")
    st.write("Ask any questions about your PDF documents")

    # Step 1: Upload PDF documents
    uploaded_files = st.file_uploader(
        "Upload your PDF documents (single or multiple)",
        type=["pdf"],
        accept_multiple_files=True,
    )

    # Step 2: Process documents and initialize vector database
    if uploaded_files:
        chunk_size = st.slider("Chunk size", min_value=100, max_value=1000, value=600, step=20)
        chunk_overlap = st.slider("Chunk overlap", min_value=10, max_value=200, value=40, step=10)
        if st.button("Generate Vector Database"):
            st.session_state["vector_db"] = process_documents(
                uploaded_files, chunk_size, chunk_overlap
            )
            # A chain built on the previous database would keep answering
            # from the old documents, so force re-initialization.
            if "qa_chain" in st.session_state:
                del st.session_state["qa_chain"]
            st.success("Vector database generated successfully!")

    # Step 3: Initialize QA chain with selected LLM model
    st.header("Initialize Question Answering (QA) Chain")
    # Only model ids that initialize_llmchain knows how to build are offered.
    llm_options = ["mistralai/Mixtral-8x7B-Instruct-v0.1"]
    llm_model = st.selectbox("Choose LLM Model", llm_options)
    temperature = st.slider("Temperature", min_value=0.01, max_value=1.0, value=0.7, step=0.1)
    max_tokens = st.slider("Max Tokens", min_value=224, max_value=4096, value=1024, step=32)
    top_k = st.slider("Top-k Samples", min_value=1, max_value=10, value=3, step=1)
    if st.button("Initialize QA Chain"):
        vector_db = st.session_state.get("vector_db")
        if vector_db is None:
            st.warning("Please generate the vector database first.")
        else:
            st.session_state["qa_chain"] = initialize_llmchain(
                llm_model, temperature, max_tokens, top_k, vector_db
            )
            st.success("QA Chain initialized successfully!")

    # Step 4: Chatbot interaction
    st.header("Chatbot")
    message = st.text_input("Type your message here")
    if st.button("Submit"):
        qa_chain = st.session_state.get("qa_chain")
        if qa_chain is None:
            st.warning("Please initialize the QA chain first.")
        elif not message.strip():
            st.warning("Please type a message first.")
        else:
            response = qa_chain(message)
            st.write(f"Chatbot Response: {response['answer']}")
# Script entry point: render the page when this file is executed directly.
if __name__ == "__main__":
    main()