import os
import json
import logging
import shutil
from tempfile import NamedTemporaryFile
from typing import List

import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.docstore.document import Document

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DOCUMENTS_FILE = "uploaded_documents.json"   # JSON registry of uploaded docs: [{"name": ..., "selected": bool}, ...]
FAISS_DB_DIR = "faiss_database"              # on-disk FAISS vector store directory
DEFAULT_MODEL = "@cf/meta/llama-2-7b-chat"
HF_TOKEN = os.getenv("HF_API_TOKEN")         # Make sure to set this environment variable
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"


def get_embeddings():
    """Return a CPU-bound HuggingFace embedding model with normalized vectors."""
    return HuggingFaceEmbeddings(
        model_name=EMBED_MODEL,
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )


def load_documents() -> list:
    """Load the uploaded-documents registry from disk; empty list if absent."""
    if os.path.exists(DOCUMENTS_FILE):
        with open(DOCUMENTS_FILE, "r") as f:
            return json.load(f)
    return []


def save_documents(documents: list) -> None:
    """Persist the uploaded-documents registry to disk."""
    with open(DOCUMENTS_FILE, "w") as f:
        json.dump(documents, f)


def load_document(file: NamedTemporaryFile) -> List[Document]:
    """Loads and splits the document into pages using PyPDF."""
    loader = PyPDFLoader(file.name)
    return loader.load_and_split()


def process_uploaded_files(files):
    """Index uploaded PDF files into the FAISS store.

    Returns a (status message, CheckboxGroup update) pair for the UI.
    The update sets *choices* (not just value) so newly uploaded documents
    actually appear as selectable checkboxes.
    """
    if not files:
        return "Please upload at least one file.", gr.update(choices=[], value=[])

    # Gradio may hand back a single file object or a list depending on config.
    files_list = files if isinstance(files, list) else [files]

    embed = get_embeddings()
    uploaded_documents = load_documents()
    total_chunks = 0
    all_data: List[Document] = []

    for file in files_list:
        try:
            data = load_document(file)
            if not data:
                continue
            all_data.extend(data)
            total_chunks += len(data)
            # NOTE(review): file.name is the temp-file path, which also becomes
            # each chunk's metadata "source" — delete/filter logic relies on this.
            if not any(doc["name"] == file.name for doc in uploaded_documents):
                uploaded_documents.append({"name": file.name, "selected": True})
        except Exception as e:
            logger.error(f"Error processing file {file.name}: {str(e)}")

    if not all_data:
        return "No valid data could be extracted from the uploaded files.", gr.update()

    try:
        if os.path.exists(FAISS_DB_DIR):
            # Extend the existing store rather than rebuilding from scratch.
            database = FAISS.load_local(FAISS_DB_DIR, embed, allow_dangerous_deserialization=True)
            database.add_documents(all_data)
        else:
            database = FAISS.from_documents(all_data, embed)
        database.save_local(FAISS_DB_DIR)
        save_documents(uploaded_documents)
        names = [doc["name"] for doc in uploaded_documents]
        return (
            f"Vector store updated successfully. Processed {total_chunks} chunks.",
            gr.update(choices=names, value=names),
        )
    except Exception as e:
        return f"Error updating vector store: {str(e)}", gr.update()


def delete_documents(selected_docs):
    """Remove the selected documents from both the registry and the FAISS store.

    Returns a (status message, CheckboxGroup update) pair for the UI.
    """
    if not selected_docs:
        return "No documents selected for deletion.", gr.update()

    uploaded_documents = load_documents()
    embed = get_embeddings()

    if not os.path.exists(FAISS_DB_DIR):
        return "No documents to delete.", gr.update(choices=[], value=[])

    database = FAISS.load_local(FAISS_DB_DIR, embed, allow_dangerous_deserialization=True)

    # Keep every chunk whose source file was NOT selected for deletion.
    docs_to_keep = [
        doc for doc in database.docstore._dict.values()
        if doc.metadata.get("source") not in selected_docs
    ]

    if not docs_to_keep:
        # Nothing left — drop the whole store directory.
        shutil.rmtree(FAISS_DB_DIR)
    else:
        # FAISS has no per-document delete by source, so rebuild from survivors.
        new_database = FAISS.from_documents(docs_to_keep, embed)
        new_database.save_local(FAISS_DB_DIR)

    uploaded_documents = [doc for doc in uploaded_documents if doc["name"] not in selected_docs]
    save_documents(uploaded_documents)

    remaining_docs = [doc["name"] for doc in uploaded_documents]
    return (
        f"Deleted documents: {', '.join(selected_docs)}",
        gr.update(choices=remaining_docs, value=remaining_docs),
    )


def get_response(query, temperature=0.2):
    """Answer a question using RAG over the selected uploaded documents.

    Retrieves the top-5 relevant chunks from documents marked selected in the
    registry, then asks the chat model to answer from that context only.
    """
    if not query.strip():
        return "Please enter a question."

    uploaded_documents = load_documents()
    # TODO(review): "selected" is always True at upload time and the UI checkbox
    # state is never written back to the registry — confirm intended behavior.
    selected_docs = [doc["name"] for doc in uploaded_documents if doc["selected"]]
    if not selected_docs:
        return "Please select at least one document to search through."

    embed = get_embeddings()
    if not os.path.exists(FAISS_DB_DIR):
        return "No documents available. Please upload PDF documents first."

    database = FAISS.load_local(FAISS_DB_DIR, embed, allow_dangerous_deserialization=True)

    # Restrict retrieval to chunks originating from the selected documents.
    filtered_docs = [
        doc for doc in database.docstore._dict.values()
        if isinstance(doc, Document) and doc.metadata.get("source") in selected_docs
    ]
    if not filtered_docs:
        return "No relevant information found in the selected documents."

    # Build a temporary sub-index over only the selected documents.
    filtered_db = FAISS.from_documents(filtered_docs, embed)
    retriever = filtered_db.as_retriever(search_kwargs={"k": 5})
    relevant_docs = retriever.get_relevant_documents(query)

    context_str = "\n".join([doc.page_content for doc in relevant_docs])
    messages = [
        {"role": "system", "content": "You are a helpful assistant that provides accurate answers based on the given context."},
        {"role": "user", "content": f"Context:\n{context_str}\n\nQuestion: {query}\n\nProvide a comprehensive answer based only on the given context."}
    ]

    client = InferenceClient(DEFAULT_MODEL, token=HF_TOKEN)
    try:
        response = client.chat_completion(
            messages=messages,
            max_tokens=1000,
            temperature=temperature,
            top_p=0.9,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error generating response: {str(e)}"


def create_interface():
    """Build and return the Gradio Blocks UI for the PDF QA system."""
    with gr.Blocks(title="PDF Question Answering System") as app:
        gr.Markdown("# PDF Question Answering System")

        with gr.Row():
            with gr.Column():
                files = gr.File(
                    label="Upload PDF Documents",
                    file_types=[".pdf"],
                    file_count="multiple"
                )
                upload_button = gr.Button("Upload and Process")
            with gr.Column():
                doc_status = gr.Textbox(label="Status", interactive=False)
                # Fixed: the correct Gradio class is CheckboxGroup (was
                # gr.Checkboxgroup, which raises AttributeError).
                doc_list = gr.CheckboxGroup(
                    label="Available Documents",
                    choices=[],
                    interactive=True
                )
                delete_button = gr.Button("Delete Selected Documents")

        with gr.Row():
            with gr.Column():
                question = gr.Textbox(
                    label="Ask a question about the documents",
                    placeholder="Enter your question here..."
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.1,
                    label="Temperature (Higher values make the output more random)"
                )
                submit_button = gr.Button("Submit Question")
            with gr.Column():
                answer = gr.Textbox(
                    label="Answer",
                    interactive=False,
                    lines=10
                )

        # Event handlers
        upload_button.click(
            fn=process_uploaded_files,
            inputs=[files],
            outputs=[doc_status, doc_list]
        )
        delete_button.click(
            fn=delete_documents,
            inputs=[doc_list],
            outputs=[doc_status, doc_list]
        )
        submit_button.click(
            fn=get_response,
            inputs=[question, temperature],
            outputs=[answer]
        )
        # Add keyboard shortcut for submitting questions
        question.submit(
            fn=get_response,
            inputs=[question, temperature],
            outputs=[answer]
        )

    return app


if __name__ == "__main__":
    app = create_interface()
    app.launch(
        server_name="0.0.0.0",  # Makes the app accessible from other machines
        server_port=7860,       # Specify port
        share=True              # Creates a public URL
    )