import os
import streamlit as st
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI

# Streamlit app title
st.title("Question Answering with Legal Documents")

# Directory containing the PDFs
pdf_folder_path = "law"  # This is the folder where all PDFs are stored
# Load and process all PDFs from the folder, keeping per-page metadata
def load_pdfs_from_folder(folder_path):
    all_docs = []
    for filename in os.listdir(folder_path):
        if filename.endswith(".pdf"):
            pdf_path = os.path.join(folder_path, filename)
            loader = PyPDFLoader(pdf_path)
            docs = loader.load()
            for doc in docs:
                # Record the originating filename so answers can cite it
                doc.metadata["source"] = filename
            all_docs.extend(docs)
    return all_docs
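
# PyPDFLoader yields one Document per page and stores a 0-based page
# index in doc.metadata["page"]; the citation display at the bottom of
# this app relies on that key.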
docs = load_pdfs_from_folder(pdf_folder_path)

# Split the documents into chunks
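# The underscore prefix on `_docs` follows the Streamlit convention for
# excluding an argument from cache hashing, so a cache decorator is
# assumed to belong here; st.cache_data (Streamlit >= 1.18) keeps the
# split from re-running on every widget interaction.
@st.cache_data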
def split_docs(_docs):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
    return text_splitter.split_documents(_docs)
splits = split_docs(docs)

# Load OpenAI embeddings
openai_api_key = st.secrets["openai_api_key"]
embedding = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Vectorstore setup (Chroma)
persist_directory = 'docs/chroma/'
vectordb = Chroma.from_documents(documents=splits, embedding=embedding, persist_directory=persist_directory)
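# Rebuilding the index on every rerun is wasteful once it has been
# persisted; a sketch for reloading it instead (assumes the classic
# Chroma constructor with an embedding_function parameter):
# vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)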
# Custom PromptTemplate that constrains answers to the retrieved context
template = """You are a legal expert and must answer questions using only the provided legal documents.
You must use the context below to find the answer. Do not guess or provide any answer that is not based on the documents.
If you don't know the answer based on the documents, simply say, "The answer is not available in the provided documents."
Keep your answer under 100 words, and always say "Thanks for asking!" at the end.
{context}
Question: {question}
Helpful Answer (based on the documents, 100 words max):"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
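# With the default "stuff" chain type, RetrievalQA pastes the retrieved
# chunks into {context} and the user's query into {question} when this
# prompt is rendered.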
# Retriever settings: MMR re-ranks fetched candidates for diversity,
# which only matters once k > 1; with k=1 only the single best chunk is returned
retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 1})
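# A plain similarity search over a few more chunks is a common
# alternative worth comparing, e.g.:
# retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": 3})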
# Build the RetrievalQA chain with the custom prompt
qa_chain = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=openai_api_key),
    retriever=retriever,
    return_source_documents=True,
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
)
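# return_source_documents=True makes the chain return multiple output
# keys, so it has to be invoked with a dict ({"query": ...}) rather
# than qa_chain.run(question).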

# Streamlit user input
question = st.text_input("Ask a question based on the legal documents:")
if st.button("Get Answer"):
    if question:
        with st.spinner('Generating answer...'):
            result = qa_chain({"query": question})
        st.write(result["result"])  # Display the concise answer

        # Display references to the source documents
        st.subheader("Referenced Source:")
        for doc in result["source_documents"]:
            doc_name = os.path.basename(doc.metadata["source"])
            page_number = doc.metadata.get('page', 'N/A')  # 0-based page index from PyPDFLoader; N/A if missing
            st.write(f"Referenced from {doc_name}, page {page_number}")
            # Optionally show the chunk content
            # st.write(doc.page_content)  # Uncomment to show the matched text
    else:
        st.error("Please ask a question.")