Update pages/jarvis.py

pages/jarvis.py CHANGED (+44 -68)
```diff
@@ -1,37 +1,35 @@
 import streamlit as st
+import os
+from pathlib import Path
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.llms import HuggingFacePipeline
-from langchain.chains import ConversationChain
-from langchain.memory import ConversationBufferMemory
+from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceEndpoint
-from pathlib import Path
-import chromadb
+from langchain.memory import ConversationBufferMemory
 from unidecode import unidecode
-
-import transformers
-import torch
-import tqdm
-import accelerate
+import chromadb
 import re
 
-
+list_llm = [
+    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "mistralai/Mistral-7B-Instruct-v0.1", "google/gemma-7b-it", "google/gemma-2b-it",
+    "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
+    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
+    "google/flan-t5-xxl"
+]
+
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=chunk_size,
-        chunk_overlap=chunk_overlap
-    )
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
-# Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
     new_client = chromadb.EphemeralClient()
@@ -39,87 +37,65 @@ def create_db(splits, collection_name):
         documents=splits,
         embedding=embedding,
         client=new_client,
-        collection_name=collection_name
-        # persist_directory=default_persist_directory
+        collection_name=collection_name
     )
     return vectordb
 
-
-# Load vector database
-def load_db():
-    embedding = HuggingFaceEmbeddings()
-    vectordb = Chroma(
-        # persist_directory=default_persist_directory,
-        embedding_function=embedding)
-    return vectordb
-
-# Initialize Langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
-
-    llm = HuggingFaceEndpoint(
-        repo_id=llm_model,
-        temperature=temperature,
-        max_new_tokens=max_tokens,
-        top_k=top_k,
-        load_in_8bit=True,
-    )
-    # Add other LLM models initialization conditions here...
-    memory = ConversationBufferMemory(
-        memory_key="chat_history",
-        output_key='answer',
-        return_messages=True
-    )
+    llm = HuggingFaceEndpoint(repo_id=llm_model, temperature=temperature, max_new_tokens=max_tokens, top_k=top_k)
+    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
     retriever = vector_db.as_retriever()
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
-        chain_type="stuff",
+        chain_type="stuff",
         memory=memory,
         return_source_documents=True,
-        verbose=False
+        verbose=False
     )
     return qa_chain
 
-
-
-
-    collection_name =
-
-
-
+def create_collection_name(file_path):
+    collection_name = Path(file_path).stem
+    collection_name = unidecode(collection_name)
+    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+    collection_name = collection_name[:50]
+    if len(collection_name) < 3:
+        collection_name = collection_name + 'xyz'
+    if not collection_name[0].isalnum():
+        collection_name = 'A' + collection_name[1:]
+    if not collection_name[-1].isalnum():
+        collection_name = collection_name[:-1] + 'Z'
+    return collection_name
 
-# Streamlit app
 def main():
     st.title("PDF-based Chatbot")
-    st.write("Ask any questions about your PDF documents")
 
-
-    uploaded_files = st.file_uploader("Upload your PDF documents (single or multiple)", type=["pdf"], accept_multiple_files=True)
+    uploaded_files = st.file_uploader("Upload PDF documents (single or multiple)", type="pdf", accept_multiple_files=True)
 
-    # Step 2: Process documents and initialize vector database
     if uploaded_files:
         chunk_size = st.slider("Chunk size", min_value=100, max_value=1000, value=600, step=20)
         chunk_overlap = st.slider("Chunk overlap", min_value=10, max_value=200, value=40, step=10)
+
         if st.button("Generate Vector Database"):
-
-
+            list_file_path = [file.name for file in uploaded_files]
+            collection_name = create_collection_name(list_file_path[0])
+            doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+            vector_db = create_db(doc_splits, collection_name)
 
-
-        st.header("Initialize Question Answering (QA) Chain")
-        llm_model = st.selectbox("Choose LLM Model", list_llm_simple)
+        llm_model = st.selectbox("Choose LLM Model", list_llm)
         temperature = st.slider("Temperature", min_value=0.01, max_value=1.0, value=0.7, step=0.1)
         max_tokens = st.slider("Max Tokens", min_value=224, max_value=4096, value=1024, step=32)
-        top_k = st.slider("Top-
+        top_k = st.slider("Top-K Samples", min_value=1, max_value=10, value=3, step=1)
+
         if st.button("Initialize QA Chain"):
             qa_chain = initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db)
-            st.success("QA Chain initialized successfully!")
 
-        # Step 4: Chatbot interaction
         st.header("Chatbot")
-        message = st.text_input("Type your message
+        message = st.text_input("Type your message")
         if st.button("Submit"):
-            response = qa_chain(message)
-            st.write(
+            response = qa_chain({"question": message, "chat_history": []})
+            st.write("Assistant:", response["answer"])
 
 if __name__ == "__main__":
     main()
```
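The new `create_collection_name` helper exists because Chroma validates collection names (roughly: 3 to 63 characters, starting and ending with an alphanumeric character). A quick trace on a hypothetical filename, assuming the function from the diff is in scope, shows how the fix-up rules fire:

```python
# "Annual Report (2023).pdf" --Path(...).stem--> "Annual Report (2023)"
#   --unidecode--> unchanged (already ASCII)
#   --re.sub('[^A-Za-z0-9]+', '-')--> "Annual-Report-2023-"
#   the trailing '-' is not alphanumeric, so the last rule swaps it for 'Z':
print(create_collection_name("Annual Report (2023).pdf"))  # -> Annual-Report-2023Z
```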
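One caveat with the new `Generate Vector Database` handler: `st.file_uploader` returns in-memory `UploadedFile` objects, and `UploadedFile.name` is only the original filename, not a path on disk, so `PyPDFLoader(file.name)` will only work if a file by that name happens to sit in the working directory. A minimal sketch of a workaround, with `save_uploads` as a hypothetical helper name:

```python
import os
import tempfile

def save_uploads(uploaded_files):
    """Write Streamlit UploadedFile objects to a temp dir and return real paths."""
    tmp_dir = tempfile.mkdtemp()
    paths = []
    for f in uploaded_files:
        path = os.path.join(tmp_dir, f.name)
        with open(path, "wb") as out:
            out.write(f.getvalue())  # UploadedFile subclasses BytesIO, so getvalue() is the raw bytes
        paths.append(path)
    return paths
```

`list_file_path = save_uploads(uploaded_files)` would then feed `load_doc` real paths.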
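A second caveat: Streamlit reruns the whole script on every widget interaction, so the `vector_db` and `qa_chain` locals created inside one `st.button` block are gone by the time the next button is pressed, and the `Submit` handler will raise a `NameError`. A sketch of one way to restructure `main` around `st.session_state` (an assumption about intent, not the committed code):

```python
if st.button("Generate Vector Database"):
    list_file_path = save_uploads(uploaded_files)  # hypothetical helper from the sketch above
    collection_name = create_collection_name(list_file_path[0])
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    st.session_state["vector_db"] = create_db(doc_splits, collection_name)  # survives reruns

if st.button("Initialize QA Chain") and "vector_db" in st.session_state:
    st.session_state["qa_chain"] = initialize_llmchain(
        llm_model, temperature, max_tokens, top_k, st.session_state["vector_db"]
    )

if st.button("Submit") and "qa_chain" in st.session_state:
    # With ConversationBufferMemory attached, the chain tracks chat_history
    # itself, so each turn only needs the question.
    response = st.session_state["qa_chain"]({"question": message})
    st.write("Assistant:", response["answer"])
```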