Spaces:
Running
Running
File size: 8,306 Bytes
1842d77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import os
import tempfile
import streamlit as st
from dotenv import load_dotenv
from pdfminer.high_level import extract_text
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain.agents import initialize_agent, load_tools
# Load environment variables from a Streamlit secrets file, checking the
# project-local location first, then the user-level one.
# NOTE(review): secrets.toml is TOML while load_dotenv expects KEY=VALUE
# (.env) syntax — flat `KEY = "value"` entries happen to parse, but nested
# TOML tables will not; consider reading st.secrets instead — confirm with
# the deployment setup.
_secrets_candidates = (
    os.path.join(os.getcwd(), ".streamlit", "secrets.toml"),
    os.path.join(os.path.expanduser("~"), ".streamlit", "secrets.toml"),
)
secrets_exists = any(os.path.exists(p) for p in _secrets_candidates)
if secrets_exists:
    # Bug fix: the original tested both paths for existence but always loaded
    # the project-local one; load the first path that actually exists.
    load_dotenv(next(p for p in _secrets_candidates if os.path.exists(p)))
# Function to extract text from PDFs
def extract_text_from_pdfs(docs):
    """Extract and concatenate text from uploaded PDF files.

    Each uploaded file is spooled to a temporary file on disk (pdfminer's
    ``extract_text`` takes a path), parsed, and the temp file removed.

    Args:
        docs: iterable of Streamlit ``UploadedFile`` objects (PDFs).

    Returns:
        str: concatenated text of all successfully parsed PDFs. Files that
        fail to parse are reported via ``st.error`` and skipped.
    """
    text = ""
    for doc in docs:
        # Bug fix: bind before the try block — if NamedTemporaryFile itself
        # raised, the original `finally: os.remove(tmp_file_path)` hit a
        # NameError (or removed the previous iteration's already-deleted file).
        tmp_file_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                tmp_file.write(doc.getbuffer())
                tmp_file_path = tmp_file.name
            text += extract_text(tmp_file_path)
        except Exception as e:
            st.error(f"Error processing {doc.name}: {e}")
        finally:
            # Remove the temp file only if it was actually created.
            if tmp_file_path is not None and os.path.exists(tmp_file_path):
                os.remove(tmp_file_path)
    return text
# Function to split text into chunks
def get_text_chunks(raw_text):
    """Split raw document text into overlapping chunks suitable for embedding.

    Args:
        raw_text: full extracted document text.

    Returns:
        list[str]: chunks of up to 1000 characters with a 50-character overlap.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
    return splitter.split_text(raw_text)
# Function to create FAISS index
def create_faiss_index(text_chunks):
    """Embed text chunks with a BGE model and build an in-memory FAISS index.

    Args:
        text_chunks: list of text chunks to embed.

    Returns:
        FAISS vector store built over the chunks.
    """
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        # Normalized embeddings make inner-product search behave like cosine.
        encode_kwargs={"normalize_embeddings": True},
    )
    return FAISS.from_texts(text_chunks, embeddings)
# Function to get the conversation chain
def get_conversation_chain(vector_store, groq_api_key):
    """Build a conversational retrieval chain over the FAISS store.

    A multi-query retriever rewrites each user question into three variants
    (via the prompt below) to widen the similarity search, and a buffer
    memory keeps the chat history for follow-up questions.

    Args:
        vector_store: FAISS vector store to retrieve from.
        groq_api_key: Groq API key for the chat model.

    Returns:
        tuple: (ConversationalRetrievalChain, ChatGroq llm instance).
    """
    llm = ChatGroq(
        temperature=0.7,
        model="llama3-70b-8192",
        api_key=groq_api_key,
        streaming=True,
        verbose=True,
    )
    # Prompt used only for query expansion, not for answering.
    query_prompt = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI language model assistant. Your task is to generate 3
different versions of the given user question to retrieve relevant documents from
a vector database. By generating multiple perspectives on the user question, your
goal is to help the user overcome some of the limitations of the distance-based
similarity search. Provide these alternative questions separated by newlines.
Original question: {question}""",
    )
    retriever = MultiQueryRetriever(
        retriever=vector_store.as_retriever(),
        llm_chain=LLMChain(llm=llm, prompt=query_prompt),
        num_queries=3,
    )
    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=ConversationBufferMemory(memory_key='chat_history', return_messages=True),
    )
    return chain, llm
# Function to get the web agent
def get_web_agent(groq_api_key):
    """Build a zero-shot ReAct agent with web-search, math and summarizer tools.

    Used as a fallback when the document chain cannot answer a question.

    Args:
        groq_api_key: Groq API key for the chat model.

    Returns:
        The initialized LangChain agent executor.
    """
    llm = ChatGroq(
        temperature=0.7,
        model="llama3-70b-8192",
        api_key=groq_api_key,
        streaming=True,
        verbose=True,
    )
    # can create custom tools
    from tools import summarizer_tool

    tools = load_tools([], llm=llm)
    tools.append(summarizer_tool)
    tools.extend(load_tools(["llm-math", "google-search"], llm=llm))
    agent = initialize_agent(
        agent='zero-shot-react-description',
        tools=tools,
        llm=llm,
        verbose=True,
        max_iterations=10,
        memory=ConversationBufferMemory(memory_key="chat_history"),
        handle_parsing_errors=True,
    )
    return agent
# Main function
def main():
    """Streamlit entry point: sidebar for PDF upload/API key, chat UI for Q&A.

    Flow: upload PDFs -> build FAISS index -> answer from documents via the
    conversational chain; fall back to the web agent when the chain says
    "I don't know".
    """
    # One-time session init; keys are created together, so checking one is enough.
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
        st.session_state.chat_history = []
        st.session_state.vector_store = None

    st.set_page_config(page_title="Multi Model Agent", page_icon=":books:")
    st.markdown("<h2 style='text-align: center;'>AI Agent 🤖</h2>", unsafe_allow_html=True)

    with st.sidebar:
        st.markdown('📖 API_KEYS [REPO](https://github.com/ANeuronI/RAG-AGENT)')
        st.title("📤 Upload Pdf ")
        docs = st.file_uploader(" ", type=["pdf"], accept_multiple_files=True)

        file_details = []
        if docs is not None:
            for doc in docs:
                file_details.append({"FileName": doc.name})
        with st.expander("Uploaded Files"):
            if file_details:
                for details in file_details:
                    st.write(f"File Name: {details['FileName']}")

        st.subheader("Start Model🧠")
        # Prefer a key from the environment/secrets; otherwise ask the user.
        groq_api_key = os.getenv("GROQ_API_KEY")
        if groq_api_key:
            st.success('Groq API key already provided!', icon='✅')
        else:
            groq_api_key = st.text_input('Enter Groq API key:', type='password', key='groq_api_key')
            if groq_api_key and (groq_api_key.startswith('gsk_') and len(groq_api_key) == 56):
                os.environ['GROQ_API_KEY'] = groq_api_key
                st.success('Groq API key provided!', icon='✅')
            else:
                st.warning('Please enter a valid Groq API key!', icon='⚠️')

        if st.button("Start Inference", key="start_inference") and docs:
            with st.spinner("Processing..."):
                raw_text = extract_text_from_pdfs(docs)
                if raw_text:
                    text_chunks = get_text_chunks(raw_text)
                    vector_store = create_faiss_index(text_chunks)
                    st.session_state.vector_store = vector_store
                    st.write("FAISS Vector Store created successfully.")
                    st.session_state.conversation, llm = get_conversation_chain(vector_store, groq_api_key)
                    st.session_state.llm = llm
                    st.session_state.web_agent = get_web_agent(groq_api_key)
                else:
                    st.error("No text extracted from the documents.")

    if st.session_state.conversation:
        # Replay the stored history so the transcript survives reruns.
        for message in st.session_state.chat_history:
            if message['role'] == 'user':
                with st.chat_message("user"):
                    st.write(message["content"])
            else:
                with st.chat_message("assistant"):
                    st.write(message["content"])

        # Bug fix: st.text_input returns "" (not None) when empty, so the
        # original `groq_api_key is None` left chat enabled with no key.
        input_disabled = not groq_api_key
        if prompt := st.chat_input("Ask your question here...", disabled=input_disabled):
            st.session_state.chat_history.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.write(prompt)
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = st.session_state.conversation({"question": prompt})
                # Fall back to web search when the chain has no grounded answer.
                if "answer" in response and "I don't know" not in response["answer"]:
                    st.session_state.chat_history.append({"role": "assistant", "content": response['answer']})
                    st.write(response['answer'])
                else:
                    with st.spinner("Searching the web..."):
                        response = st.session_state.web_agent.run(prompt)
                    st.session_state.chat_history.append({"role": "assistant", "content": response})
                    st.write(response)


if __name__ == '__main__':
    main()
|