Huzaifa367 committed · verified
Commit c6308ec · Parent(s): 5bae4c3

Update pages/jarvis.py

Files changed (1): pages/jarvis.py (+44 −68)
pages/jarvis.py CHANGED
@@ -1,37 +1,35 @@
 import streamlit as st
+import os
+from pathlib import Path
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.llms import HuggingFacePipeline
-from langchain.chains import ConversationChain
-from langchain.memory import ConversationBufferMemory
+from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceEndpoint
-from pathlib import Path
-import chromadb
+from langchain.memory import ConversationBufferMemory
 from unidecode import unidecode
-from transformers import AutoTokenizer
-import transformers
-import torch
-import tqdm
-import accelerate
+import chromadb
 import re
 
-# Function to load PDF document and create doc splits
+list_llm = [
+    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "mistralai/Mistral-7B-Instruct-v0.1", "google/gemma-7b-it", "google/gemma-2b-it",
+    "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
+    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
+    "google/flan-t5-xxl"
+]
+
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=chunk_size,
-        chunk_overlap=chunk_overlap
-    )
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
-# Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
     new_client = chromadb.EphemeralClient()
@@ -39,87 +37,65 @@ def create_db(splits, collection_name):
         documents=splits,
         embedding=embedding,
         client=new_client,
-        collection_name=collection_name,
-        # persist_directory=default_persist_directory
+        collection_name=collection_name
     )
     return vectordb
 
-
-# Load vector database
-def load_db():
-    embedding = HuggingFaceEmbeddings()
-    vectordb = Chroma(
-        # persist_directory=default_persist_directory,
-        embedding_function=embedding)
-    return vectordb
-
-# Initialize Langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
-    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-            load_in_8bit=True,
-        )
-    # Add other LLM models initialization conditions here...
-    memory = ConversationBufferMemory(
-        memory_key="chat_history",
-        output_key='answer',
-        return_messages=True
-    )
+    llm = HuggingFaceEndpoint(repo_id=llm_model, temperature=temperature, max_new_tokens=max_tokens, top_k=top_k)
+    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
     retriever = vector_db.as_retriever()
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
-        chain_type="stuff",
+        chain_type="stuff",
         memory=memory,
         return_source_documents=True,
-        verbose=False,
+        verbose=False
     )
     return qa_chain
 
-# Function to process uploaded PDFs and initialize the database
-def process_documents(list_file_obj, chunk_size, chunk_overlap):
-    list_file_path = [x.name for x in list_file_obj if x is not None]
-    collection_name = create_collection_name(list_file_path[0])
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-    vector_db = create_db(doc_splits, collection_name)
-    return vector_db
+def create_collection_name(file_path):
+    collection_name = Path(file_path).stem
+    collection_name = unidecode(collection_name)
+    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+    collection_name = collection_name[:50]
+    if len(collection_name) < 3:
+        collection_name = collection_name + 'xyz'
+    if not collection_name[0].isalnum():
+        collection_name = 'A' + collection_name[1:]
+    if not collection_name[-1].isalnum():
+        collection_name = collection_name[:-1] + 'Z'
+    return collection_name
 
-# Streamlit app
 def main():
     st.title("PDF-based Chatbot")
-    st.write("Ask any questions about your PDF documents")
 
-    # Step 1: Upload PDF documents
-    uploaded_files = st.file_uploader("Upload your PDF documents (single or multiple)", type=["pdf"], accept_multiple_files=True)
+    uploaded_files = st.file_uploader("Upload PDF documents (single or multiple)", type="pdf", accept_multiple_files=True)
 
-    # Step 2: Process documents and initialize vector database
     if uploaded_files:
         chunk_size = st.slider("Chunk size", min_value=100, max_value=1000, value=600, step=20)
         chunk_overlap = st.slider("Chunk overlap", min_value=10, max_value=200, value=40, step=10)
+
         if st.button("Generate Vector Database"):
-            vector_db = process_documents(uploaded_files, chunk_size, chunk_overlap)
-            st.success("Vector database generated successfully!")
+            list_file_path = [file.name for file in uploaded_files]
+            collection_name = create_collection_name(list_file_path[0])
+            doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+            vector_db = create_db(doc_splits, collection_name)
 
-        # Step 3: Initialize QA chain with selected LLM model
-        st.header("Initialize Question Answering (QA) Chain")
-        llm_model = st.selectbox("Choose LLM Model", list_llm_simple)
+        llm_model = st.selectbox("Choose LLM Model", list_llm)
         temperature = st.slider("Temperature", min_value=0.01, max_value=1.0, value=0.7, step=0.1)
        max_tokens = st.slider("Max Tokens", min_value=224, max_value=4096, value=1024, step=32)
-        top_k = st.slider("Top-k Samples", min_value=1, max_value=10, value=3, step=1)
+        top_k = st.slider("Top-K Samples", min_value=1, max_value=10, value=3, step=1)
+
        if st.button("Initialize QA Chain"):
             qa_chain = initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db)
-            st.success("QA Chain initialized successfully!")
 
-        # Step 4: Chatbot interaction
         st.header("Chatbot")
-        message = st.text_input("Type your message here")
+        message = st.text_input("Type your message")
        if st.button("Submit"):
-            response = qa_chain(message)
-            st.write(f"Chatbot Response: {response['answer']}")
+            response = qa_chain({"question": message, "chat_history": []})
+            st.write("Assistant:", response["answer"])
 
 if __name__ == "__main__":
     main()
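
For readers skimming the new create_collection_name helper: it normalizes a PDF filename into something Chroma will accept as a collection name (Chroma names are, roughly, 3–63 characters starting and ending with an alphanumeric character, which is what the length and edge-character fixups guard). A small worked example, with a hypothetical input, tracing the committed code:

    # Hypothetical input; assumes create_collection_name from the diff is in scope.
    # Path("docs/Résumé 2024.pdf").stem  -> "Résumé 2024"
    # unidecode("Résumé 2024")           -> "Resume 2024"
    # re.sub('[^A-Za-z0-9]+', '-', ...)  -> "Resume-2024"  (already 3+ chars, alnum ends)
    assert create_collection_name("docs/Résumé 2024.pdf") == "Resume-2024"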
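
One caveat worth flagging in the new "Generate Vector Database" block: st.file_uploader returns in-memory UploadedFile objects, and file.name is only the original filename, not a path on disk, so PyPDFLoader may fail to open it. A minimal sketch of one way to bridge that gap (the save_uploads helper is hypothetical, not part of this commit):

    import tempfile
    from pathlib import Path

    def save_uploads(uploaded_files):
        # Hypothetical helper: write each Streamlit UploadedFile to a real
        # temporary file so path-based loaders like PyPDFLoader can open it.
        # delete=False keeps the file on disk after the context manager closes.
        paths = []
        for f in uploaded_files:
            with tempfile.NamedTemporaryFile(delete=False, suffix=Path(f.name).suffix) as tmp:
                tmp.write(f.getbuffer())
                paths.append(tmp.name)
        return paths

With such a helper, list_file_path = save_uploads(uploaded_files) would stand in for the [file.name for file in uploaded_files] comprehension.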
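
A second caveat: Streamlit reruns the whole script on every interaction, and locals created inside an if st.button(...): block do not survive the rerun, so vector_db and qa_chain as committed will be undefined by the time "Submit" is clicked. A hedged sketch of persisting them with st.session_state (the names mirror the committed code, but the wiring is an assumption, not part of this commit):

    # Sketch: stash the heavy objects so later reruns and button clicks see them.
    if st.button("Generate Vector Database"):
        list_file_path = [file.name for file in uploaded_files]
        collection_name = create_collection_name(list_file_path[0])
        doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
        st.session_state["vector_db"] = create_db(doc_splits, collection_name)

    if st.button("Initialize QA Chain") and "vector_db" in st.session_state:
        st.session_state["qa_chain"] = initialize_llmchain(
            llm_model, temperature, max_tokens, top_k, st.session_state["vector_db"]
        )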
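
Finally, a deployment note: HuggingFaceEndpoint calls the hosted Hugging Face Inference API, so the Space or local environment needs an API token; langchain_community reads it from the HUGGINGFACEHUB_API_TOKEN environment variable (or an explicit huggingfacehub_api_token argument). A quick guard one might add near the top of the page, assuming that variable name:

    import os

    # Fail early with a readable message instead of a deep stack trace from the
    # endpoint call; the variable name is langchain_community's documented default.
    if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
        raise RuntimeError("Set HUGGINGFACEHUB_API_TOKEN before initializing the QA chain")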