import os
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate

# Set Paths
DATA_PATH = "dataFolder/"
DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"

# Hugging Face Credentials
HF_TOKEN = os.environ.get("HF_TOKEN")
HUGGINGFACE_REPO_ID = "meta-llama/Llama-3.2-3B-Instruct"

# Step 1: Load PDF Files
def load_pdf_files(data_path):
    loader = DirectoryLoader(data_path, glob="*.pdf", loader_cls=PyPDFLoader)
    documents = loader.load()
    return documents

# Step 2: Create Chunks
def create_chunks(documents):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    text_chunks = text_splitter.split_documents(documents)
    return text_chunks

# Step 3: Generate Embeddings
def get_embedding_model():
    CACHE_DIR = "/tmp/models_cache"
    os.makedirs(CACHE_DIR, exist_ok=True)

    embedding_model = HuggingFaceEmbeddings(
        model_name="rishi002/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR
    )
    return embedding_model

# Step 4: Store Embeddings in FAISS
def store_embeddings(text_chunks, embedding_model, db_path):
    db = FAISS.from_documents(text_chunks, embedding_model)
    db.save_local(db_path)
    return db

# Step 5: Load FAISS Database
def load_faiss_db(db_path, embedding_model):
    # allow_dangerous_deserialization is acceptable here: the index is written locally by this script
    return FAISS.load_local(db_path, embedding_model, allow_dangerous_deserialization=True)

# Step 6: Load LLM Model
def load_llm(huggingface_repo_id):
    # Pass the API token and generation length as top-level arguments; recent
    # langchain_huggingface releases reject a "token" entry inside model_kwargs.
    return HuggingFaceEndpoint(
        repo_id=huggingface_repo_id,
        temperature=0.3,
        max_new_tokens=512,
        huggingfacehub_api_token=HF_TOKEN
    )

# Step 7: Set Custom Prompt
CUSTOM_PROMPT_TEMPLATE = """
Use the provided context to answer the user's question.
If the answer is unknown, say you don't know. Do not make up information.
Only respond based on the context.

Context: {context}
Question: {question}

Start your answer directly.
"""

def set_custom_prompt(template):
    return PromptTemplate(template=template, input_variables=["context", "question"])

# Step 8: Create Retrieval QA Chain
def create_qa_chain(llm, db):
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,  # ask_question() reads response["source_documents"]
        chain_type_kwargs={"prompt": set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)}
    )

# Create and load all models and FAISS (for Gradio)
def prepare_qa_system():
    # Load and process PDFs, create FAISS index, etc.
    print("πŸ”„ Loading PDFs...")
    documents = load_pdf_files(DATA_PATH)
    
    print("πŸ“„ Creating Chunks...")
    text_chunks = create_chunks(documents)
    
    print("🧠 Generating Embeddings...")
    embedding_model = get_embedding_model()
    
    print("πŸ’Ύ Storing in FAISS...")
    store_embeddings(text_chunks, embedding_model, DB_FAISS_PATH)

    print("πŸ”„ Loading FAISS Database...")
    db = load_faiss_db(DB_FAISS_PATH, embedding_model)

    print("πŸ€– Loading LLM...")
    llm = load_llm(HUGGINGFACE_REPO_ID)

    print("πŸ”— Creating QA Chain...")
    qa_chain = create_qa_chain(llm, db)

    return qa_chain

# Create the QA system and get the chain ready
qa_chain = prepare_qa_system()

# Gradio Interface function
def ask_question(query: str):
    try:
        response = qa_chain.invoke({"query": query})
        return response["result"], [doc.metadata for doc in response["source_documents"]]
    except Exception as e:
        return f"Error: {str(e)}", []