Spaces: Runtime error
Commit abb6f94 · Parent(s): 66bdce2
application addd
Browse files
- Dockerfile +10 -0
- app/__init__.py +5 -0
- app/__pycache__/__init__.cpython-310.pyc +0 -0
- app/__pycache__/logging_config.cpython-310.pyc +0 -0
- app/__pycache__/main.cpython-310.pyc +0 -0
- app/__pycache__/settings.cpython-310.pyc +0 -0
- app/api/__init__.py +2 -0
- app/api/__pycache__/__init__.cpython-310.pyc +0 -0
- app/api/__pycache__/answer.cpython-310.pyc +0 -0
- app/api/__pycache__/upload.cpython-310.pyc +0 -0
- app/api/answer.py +63 -0
- app/api/upload.py +173 -0
- app/data_pipeline/__init__.py +2 -0
- app/data_pipeline/__pycache__/__init__.cpython-310.pyc +0 -0
- app/data_pipeline/__pycache__/data_loader.cpython-310.pyc +0 -0
- app/data_pipeline/__pycache__/embedding_manager.cpython-310.pyc +0 -0
- app/data_pipeline/data_loader.py +153 -0
- app/data_pipeline/embedding_manager.py +109 -0
- app/logging_config.py +45 -0
- app/main.py +32 -0
- app/rag_pipeline/__init__.py +3 -0
- app/rag_pipeline/__pycache__/__init__.cpython-310.pyc +0 -0
- app/rag_pipeline/__pycache__/chroma_client.cpython-310.pyc +0 -0
- app/rag_pipeline/__pycache__/model_initializer.cpython-310.pyc +0 -0
- app/rag_pipeline/__pycache__/model_loader.cpython-310.pyc +0 -0
- app/rag_pipeline/__pycache__/prompt_utils.cpython-310.pyc +0 -0
- app/rag_pipeline/__pycache__/retriever_chain.cpython-310.pyc +0 -0
- app/rag_pipeline/__pycache__/retriver_chain.cpython-310.pyc +0 -0
- app/rag_pipeline/chroma_client.py +15 -0
- app/rag_pipeline/model_initializer.py +50 -0
- app/rag_pipeline/model_loader.py +148 -0
- app/rag_pipeline/prompt_utils.py +31 -0
- app/rag_pipeline/retriever_chain.py +136 -0
- app/settings.py +33 -0
- requirements.txt +0 -0
Dockerfile
CHANGED
@@ -0,0 +1,10 @@
FROM python:3.9

WORKDIR /app

COPY ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY app/ app/

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app/__init__.py
ADDED
@@ -0,0 +1,5 @@
# from main import create_app
# from api.answer import answer_router


# app = create_app()
app/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (125 Bytes)

app/__pycache__/logging_config.cpython-310.pyc
ADDED
Binary file (997 Bytes)

app/__pycache__/main.cpython-310.pyc
ADDED
Binary file (852 Bytes)

app/__pycache__/settings.cpython-310.pyc
ADDED
Binary file (1.16 kB)
app/api/__init__.py
ADDED
@@ -0,0 +1,2 @@
from app.api.answer import answer_router
from app.api.upload import upload_router
app/api/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (233 Bytes)

app/api/__pycache__/answer.cpython-310.pyc
ADDED
Binary file (1.97 kB)

app/api/__pycache__/upload.cpython-310.pyc
ADDED
Binary file (3.57 kB)
app/api/answer.py
ADDED
@@ -0,0 +1,63 @@
from fastapi import APIRouter, Request, Body, HTTPException
import logging
from typing import Dict

from app.rag_pipeline.model_initializer import initialize_models
from app.rag_pipeline.retriever_chain import RetrieverChain
from app.settings import Config

answer_router = APIRouter()

logger = logging.getLogger(__name__)
import warnings
warnings.filterwarnings("ignore")
conf = Config()

OPENAI_API_KEY = conf.API_KEY
MODEL_ID = conf.MODEL_ID
MODEL_BASENAME = conf.MODEL_BASENAME
COLLECTION_NAME = conf.COLLECTION_NAME
PERSIST_DIRECTORY = conf.PERSIST_DIRECTORY

# print(OPENAI_API_KEY, MODEL_ID, MODEL_BASENAME, PERSIST_DIRECTORY, COLLECTION_NAME)

embedding_model, llm_model = initialize_models(OPENAI_API_KEY, model_id=MODEL_ID, model_basename=MODEL_BASENAME)

def validate_question(data: Dict) -> str:
    """Extract and validate the 'question' field from the incoming data."""
    question = data.get("question")
    if not question or not isinstance(question, str) or not question.strip():
        logger.warning("Received invalid question input.")
        raise HTTPException(status_code=400, detail="Question must be a non-empty string.")
    return question

@answer_router.post('/answer', response_model=dict)
async def generate_answer(data: Dict = Body(...)) -> Dict:
    try:
        # Validate and extract the question
        question = validate_question(data)

        # Log incoming question
        logger.info(f"Received question: {question}")

        # Generate the answer
        retriever_qa = RetrieverChain(
            collection_name=COLLECTION_NAME, embedding_function=embedding_model, persist_directory=PERSIST_DIRECTORY)
        answer = retriever_qa.get_response(user_input=question, llm=llm_model)

        # answer = f"Generated answer for: {question}"

        # Log generated answer
        logger.info(f"Generated answer: {answer}")

        return {"answer": answer}

    except HTTPException as http_exc:
        logger.error(f"HTTP error: {http_exc.detail}")
        raise http_exc  # Re-raise the HTTPException to return the error response
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="An internal server error occurred.")
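A quick way to exercise this route once the Space is up — a minimal sketch, assuming the service is reachable on port 7860 (the port exposed by the Dockerfile) and that documents were already indexed via /upload; the URL and question are placeholders:

# Hypothetical client-side check of the /answer route.
import requests

resp = requests.post(
    "http://localhost:7860/answer",   # assumes the uvicorn command from the Dockerfile
    json={"question": "What does the uploaded document say about payment terms?"},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["answer"])          # the endpoint returns {"answer": ...}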
app/api/upload.py
ADDED
@@ -0,0 +1,173 @@
from fastapi import FastAPI, APIRouter, UploadFile, HTTPException
from fastapi import File, UploadFile
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from pathlib import Path
from typing import List
import os
import shutil
import logging


from app.data_pipeline.data_loader import DocumentLoader
from app.data_pipeline.embedding_manager import split_text, initialize_embedding_model, create_and_store_embeddings
import warnings
warnings.filterwarnings("ignore")


logger = logging.getLogger(__name__)

from app.settings import Config
conf = Config()

upload_router = APIRouter()

UPLOAD_DIR = conf.UPLOAD_DIR

COLLECTION_NAME = conf.COLLECTION_NAME
PERSIST_DIRECTORY = conf.PERSIST_DIRECTORY


# Type of files allowed to be uploaded
def is_allowed_file(filename):
    allowed_extensions = {"pdf", "csv", "doc", "docx", "txt", "xlsx", "xls"}
    return "." in filename and filename.rsplit(".", 1)[1].lower() in allowed_extensions


def empty_folder(folder_path):
    # Check if the folder exists
    if os.path.exists(folder_path):
        # Iterate through all items in the folder
        for item in os.listdir(folder_path):
            item_path = os.path.join(folder_path, item)
            # Remove files and folders
            if os.path.isfile(item_path) or os.path.islink(item_path):
                os.remove(item_path)
            elif os.path.isdir(item_path):
                shutil.rmtree(item_path)
        print(f"The folder '{folder_path}' has been emptied.")
    else:
        print(f"The folder '{folder_path}' does not exist.")




@upload_router.post("/upload")
async def upload_files(files: List[UploadFile] = File(...)):
    try:
        # Empty the upload directory
        empty_folder(UPLOAD_DIR)
        logger.info(f"{UPLOAD_DIR} is now empty.")

        # Check if UPLOAD_DIR exists
        if not os.path.exists(UPLOAD_DIR):
            logger.error(f"Upload directory '{UPLOAD_DIR}' does not exist.")
            return JSONResponse(content={"error": f"Folder '{UPLOAD_DIR}' does not exist"}, status_code=404)

        # Save uploaded files
        for uploaded_file in files:
            if not is_allowed_file(uploaded_file.filename):
                logger.error(f"File type of '{uploaded_file.filename}' not allowed.")
                return JSONResponse(content={"error": "File type not allowed"}, status_code=400)

            file_path = os.path.join(UPLOAD_DIR, uploaded_file.filename)
            with open(file_path, "wb") as buffer:
                buffer.write(uploaded_file.file.read())
            logger.info(f"File '{uploaded_file.filename}' uploaded successfully.")

        # Load documents from the upload directory
        try:
            document_loader = DocumentLoader(UPLOAD_DIR)
            documents = document_loader.load_all_documents()
            logger.info(f"Loaded {len(documents)} documents.")
        except Exception as e:
            logger.error(f"Error loading documents: {e}")
            return JSONResponse(content={"error": "Failed to load documents"}, status_code=500)

        # Process documents into chunks for embedding
        try:
            chunks = split_text(documents)
            logger.info(f"Processed {len(chunks)} chunks for embedding.")
        except Exception as e:
            logger.error(f"Error processing documents: {e}")
            return JSONResponse(content={"error": "Failed to process documents"}, status_code=500)

        # Initialize the embedding model
        try:
            embedding_function = initialize_embedding_model()
        except Exception as e:
            logger.error(f"Error initializing embedding model: {e}")
            return JSONResponse(content={"error": "Failed to initialize embedding model"}, status_code=500)

        # Create and store embeddings
        try:
            create_and_store_embeddings(chunks, COLLECTION_NAME, embedding_function, PERSIST_DIRECTORY)
            logger.info("Embeddings created and stored successfully.")
        except Exception as e:
            logger.error(f"Error creating or storing embeddings: {e}")
            return JSONResponse(content={"error": "Failed to create and store embeddings"}, status_code=500)

        # Return success message if everything is successful
        return JSONResponse(content={"message": "Documents successfully loaded and processed."})

    except Exception as e:
        logger.error(f"Unexpected error in upload_files endpoint: {e}")
        raise HTTPException(status_code=500, detail="Internal server error.")


# @upload_router.post("/upload")
# async def upload_files(files: List[UploadFile] = File(...)):

#     empty_folder(UPLOAD_DIR)
#     logger.info(f" {UPLOAD_DIR} is empty Now")

#     if not os.path.exists(UPLOAD_DIR):
#         logger.error(f"{UPLOAD_DIR}' does not exist")
#         return JSONResponse(content={"error": f"Folder '{UPLOAD_DIR}' does not exist"}, status_code=404)

#     for uploaded_file in files:
#         if not is_allowed_file(uploaded_file.filename):
#             logger.error(f"File type not allowed")
#             return JSONResponse(content={"error": "File type not allowed"}, status_code=400)

#         file_path = os.path.join(UPLOAD_DIR, uploaded_file.filename)
#         with open(file_path, "wb") as buffer:
#             buffer.write(uploaded_file.file.read())

#     logger.info(f"Files uploaded successfully")

#     try:
#         document_loader = DocumentLoader(UPLOAD_DIR)
#         documents = document_loader.load_all_documents()
#         logger.info(f"Loaded {len(documents)} documents.")
#     except Exception as e:
#         logger.error(f"Error loading documents: {e}")
#         return

#     try:
#         chunks = split_text(documents)
#         logger.info(f"Processed {len(chunks)} chunks for embedding.", )
#     except Exception as e:
#         logger.error(f"Error processing documents: {e}")
#         return

#     try:
#         embedding_function = initialize_embedding_model()
#     except Exception:
#         return  # Stop execution if embedding model fails

#     create_and_store_embeddings(chunks, COLLECTION_NAME, embedding_function, PERSIST_DIRECTORY)
#     logger.info(f'Documents Successfully loades')
#     return JSONResponse(content={"message": "Documents Successfully loades"})
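The matching client call for this route — a minimal sketch using requests; the file names are placeholders, and the multipart field name must be "files" to match the List[UploadFile] parameter:

# Hypothetical client call to the /upload route.
import requests

with open("report.pdf", "rb") as f1, open("notes.txt", "rb") as f2:
    files = [
        ("files", ("report.pdf", f1, "application/pdf")),
        ("files", ("notes.txt", f2, "text/plain")),
    ]
    resp = requests.post("http://localhost:7860/upload", files=files, timeout=300)

print(resp.status_code, resp.json())   # expect {"message": "Documents successfully loaded and processed."}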
app/data_pipeline/__init__.py
ADDED
@@ -0,0 +1,2 @@
from app.data_pipeline.data_loader import DocumentLoader
from app.data_pipeline.embedding_manager import *
app/data_pipeline/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (259 Bytes)

app/data_pipeline/__pycache__/data_loader.cpython-310.pyc
ADDED
Binary file (3.84 kB)

app/data_pipeline/__pycache__/embedding_manager.cpython-310.pyc
ADDED
Binary file (2.71 kB)
app/data_pipeline/data_loader.py
ADDED
@@ -0,0 +1,153 @@
import os
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders.csv_loader import UnstructuredCSVLoader
from langchain_community.document_loaders.excel import UnstructuredExcelLoader
from langchain_community.document_loaders import PDFMinerLoader
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders import Docx2txtLoader
from langchain.docstore.document import Document

import logging

INGEST_THREADS = os.cpu_count() or 8

DOCUMENT_MAP = {
    ".txt": TextLoader,
    ".md": TextLoader,
    ".pdf": PDFMinerLoader,
    ".csv": CSVLoader,
    ".csv": UnstructuredCSVLoader,
    ".xls": UnstructuredExcelLoader,
    ".xlsx": UnstructuredExcelLoader,
    ".docx": Docx2txtLoader
    # Add additional file types here if necessary
}

logger = logging.getLogger(__name__)

class DocumentLoader():
    def __init__(self, source_dir: str):
        """
        Initializes the loader with the directory path from which to load documents.
        """
        self.source_dir = source_dir
        logger.info(f"DocumentLoader initialized with source directory: {self.source_dir}")

    def load_single_document(self, file_path: str):
        """
        Loads a single document based on its file extension using the appropriate loader.

        Args:
            file_path (str): Path to the document file.

        Returns:
            List[Document]: Loaded document(s) as LangChain Document instances.
        """

        file_extension = os.path.splitext(file_path)[1]
        loader_class = DOCUMENT_MAP.get(file_extension)

        if loader_class:
            loader = loader_class(file_path)
            logger.info(f"Loading document: {file_path}")
            try:
                documents = loader.load()
                logger.info(f"Successfully loaded document: {file_path}")
                return documents
            except Exception as e:
                logger.error(f"Error loading document {file_path}: {e}", exc_info=True)
                raise
        else:
            logger.warning(f"Unsupported document type for file: {file_path}")
            raise ValueError(f"Unsupported document type: {file_extension}")


    def load_all_documents(self) -> list[Document]:
        """
        Loads all documents from the source directory, including documents in subdirectories.

        Returns:
            List[Document]: List of all loaded documents from the source directory.
        """
        paths = self._gather_file_paths()  # Gather file paths of documents to load
        all_docs = []

        logger.info(f"Loading all documents from directory: {self.source_dir}")

        # # Load each document sequentially
        # for file_path in paths:
        #     documents = self.load_single_document(file_path)
        #     all_docs.extend(documents)  # Append loaded documents to the result list

        # # return all_docs


        for file_path in paths:
            try:
                documents = self.load_single_document(file_path)
                all_docs.extend(documents)  # Append loaded documents to the result list
            except ValueError as e:
                logger.error(f"Skipping file {file_path}: {e}")
            except Exception as e:
                logger.error(f"An unexpected error occurred while loading {file_path}: {e}", exc_info=True)

        logger.info(f"Finished loading documents. Total documents loaded: {len(all_docs)}")
        return all_docs

    def _gather_file_paths(self):
        """
        Walks through the source directory and gathers file paths of documents
        that match the supported file types in DOCUMENT_MAP.

        Returns:
            List[str]: List of file paths for documents to load.
        """
        file_paths = []
        logger.debug(f"Scanning for files in directory: {self.source_dir}")
        for root, _, files in os.walk(self.source_dir):
            for file_name in files:
                file_extension = os.path.splitext(file_name)[1]
                if file_extension in DOCUMENT_MAP:
                    full_path = os.path.join(root, file_name)
                    file_paths.append(full_path)
                    logger.debug(f"Found document: {full_path}")

        logger.info(f"Total files found for loading: {len(file_paths)}")
        return file_paths



# if __name__ == "__main__":
#     source_directory = os.path.join(os.path.dirname(__file__), '..', 'Data')
#     document_loader = DocumentLoader(source_directory)

#     documents = document_loader.load_all_documents()


# from langchain_community.embeddings import OpenAIEmbeddings
# from langchain_community.vectorstores import FAISS
# directory_path = os.path.join(os.path.dirname(__file__), '..', 'Data')
# documents = load_documents(directory_path)
# print(documents)

# print(os.path.join(os.path.dirname(__file__), '..', 'Data'))
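DocumentLoader can also be exercised on its own, outside the upload route — a minimal sketch, assuming a local folder containing a few supported files (the path is a placeholder):

# Standalone usage sketch of DocumentLoader.
from app.data_pipeline.data_loader import DocumentLoader

loader = DocumentLoader("uploads")       # any folder with .pdf/.txt/.docx/.csv/.xlsx files
docs = loader.load_all_documents()       # unsupported extensions are skipped and logged
print(len(docs), "documents loaded")
print(docs[0].metadata if docs else "no documents found")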
app/data_pipeline/embedding_manager.py
ADDED
@@ -0,0 +1,109 @@
import os
import logging
from typing import List
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings


from app.settings import Config

conf = Config()

OPENAI_API_KEY = conf.API_KEY
PERSIST_DIRECTORY = conf.PERSIST_DIRECTORY
COLLECTION_NAME = conf.COLLECTION_NAME


# Set up logging
import logging

logger = logging.getLogger(__name__)

def initialize_embedding_model():
    """Initialize the embedding model based on the availability of the OpenAI API key."""
    try:
        if OPENAI_API_KEY:
            logger.info("Using OpenAI embedding model.")
            embedding_model = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
        else:
            logger.info("Using Hugging Face embedding model.")
            embedding_model = HuggingFaceEmbeddings(
                model_name=conf.MODEL_NAME,
                model_kwargs=conf.MODEL_KWARGS,
                encode_kwargs=conf.ENCODE_KWARGS
            )
        return embedding_model
    except Exception as e:
        logger.error(f"Error initializing embedding model: {e}")
        raise



def split_text(documents: List[str]) -> List[str]:
    """Split documents into smaller chunks."""
    try:
        logger.info("Splitting documents into chunks...")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=conf.CHUNK_SIZE, chunk_overlap=conf.CHUNK_OVERLAP)
        chunks = text_splitter.split_documents(documents)
        logger.info("Document splitting completed.")
        return chunks
    except Exception as e:
        logger.error(f"Error splitting text: {e}")
        raise

def get_chroma_client(collection_name: str, embedding_function, persist_directory: str):
    """Initialize and return a Chroma client for a specific collection."""
    try:
        logger.info(f"Creating Chroma client for collection: {collection_name}")
        return Chroma(
            collection_name=collection_name,
            embedding_function=embedding_function,
            persist_directory=persist_directory
        )
    except Exception as e:
        logger.error(f"Error creating Chroma client: {e}")
        raise

def create_and_store_embeddings(chunks: List[str], collection_name: str, embedding_function, persist_directory: str):
    """Create and store embeddings for document chunks."""
    try:
        vector_db = get_chroma_client(collection_name, embedding_function, persist_directory)
        vector_db.add_documents(chunks)
        logger.info(f"Embeddings created for collection {collection_name} and saved to {persist_directory}.")
    except Exception as e:
        logger.error(f"Error creating and storing embeddings: {e}")
        raise

# def main():
#     source_directory = conf.DATA_DIRECTORY
#     document_loader = DocumentLoader(source_directory)
#     try:
#         documents = document_loader.load_all_documents()
#         logger.info(f"Loaded {len(documents)} documents.")
#     except Exception as e:
#         logger.error(f"Error loading documents: {e}")
#         return

#     # Split documents into chunks
#     try:
#         chunks = split_text(documents)
#         logger.info(f"Processed {len(chunks)} chunks for embedding.", )
#     except Exception as e:
#         logger.error(f"Error processing documents: {e}")
#         return

#     # Initialize embedding model
#     try:
#         embedding_function = initialize_embedding_model()
#     except Exception:
#         return  # Stop execution if embedding model fails

#     # Create and store embeddings
#     create_and_store_embeddings(chunks, COLLECTION_NAME, embedding_function, PERSIST_DIRECTORY)

# if __name__ == "__main__":
#     main()
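Taken together with data_loader.py, these helpers form the indexing pipeline; a minimal offline sketch, assuming the same Config-driven collection and persist directory as above and a placeholder upload folder:

# Offline indexing sketch: load -> split -> embed -> store.
from app.data_pipeline.data_loader import DocumentLoader
from app.data_pipeline.embedding_manager import (
    split_text, initialize_embedding_model, create_and_store_embeddings,
    COLLECTION_NAME, PERSIST_DIRECTORY,
)

docs = DocumentLoader("uploads").load_all_documents()
chunks = split_text(docs)                    # RecursiveCharacterTextSplitter with CHUNK_SIZE/CHUNK_OVERLAP
embedding = initialize_embedding_model()     # OpenAI if an API key is set, otherwise Hugging Face
create_and_store_embeddings(chunks, COLLECTION_NAME, embedding, PERSIST_DIRECTORY)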
app/logging_config.py
ADDED
@@ -0,0 +1,45 @@
import logging
from logging.handlers import RotatingFileHandler
import os
from app.settings import Config

# Define whether logging should be enabled
is_logging = True  # Set to False to disable logging

conf = Config()
# Log file settings
LOG_FILE = f"{conf.LOG_DIR}/app.log"
LOG_LEVEL = logging.INFO

def setup_logging():
    """Configure logging for the entire application."""
    if not is_logging:
        # Disable all logging if is_logging is False
        logging.disable(logging.CRITICAL)
        return

    # Create a logger
    logger = logging.getLogger()
    logger.setLevel(LOG_LEVEL)

    # Create a rotating file handler to store logs in a file
    file_handler = RotatingFileHandler(LOG_FILE, maxBytes=5 * 1024 * 1024, backupCount=2, encoding='utf-8')
    file_handler.setLevel(LOG_LEVEL)

    # Create a stream handler to log to console
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(LOG_LEVEL)

    # Define the format for log messages
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)

    # Add the handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)

    logging.info("Logging setup complete")

# Call the function to set up logging when the module is imported
setup_logging()
app/main.py
ADDED
@@ -0,0 +1,32 @@
from fastapi import FastAPI, Body, Request, Form
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from app.api.answer import answer_router
from app.api.upload import upload_router

from app import logging_config
from app.settings import Config
import warnings
warnings.filterwarnings("ignore")

conf = Config()

def create_app():

    app = FastAPI()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    app.include_router(answer_router)  # , prefix="/api/v1"
    app.include_router(upload_router)  # , prefix="/api/v1"

    return app


app = create_app()
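To run the service without Docker, a minimal sketch from the repository root (assuming requirements.txt is installed); it mirrors the CMD in the Dockerfile:

# run_local.py (hypothetical helper) -- start the API the same way the Dockerfile does.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=7860)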
app/rag_pipeline/__init__.py
ADDED
@@ -0,0 +1,3 @@
from app.rag_pipeline.prompt_utils import contex_retriever_prompt, conversion_retriever_prompt
from app.rag_pipeline.retriever_chain import RetrieverChain
from app.rag_pipeline.chroma_client import get_chroma_client
app/rag_pipeline/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (396 Bytes)

app/rag_pipeline/__pycache__/chroma_client.cpython-310.pyc
ADDED
Binary file (617 Bytes)

app/rag_pipeline/__pycache__/model_initializer.cpython-310.pyc
ADDED
Binary file (1.5 kB)

app/rag_pipeline/__pycache__/model_loader.cpython-310.pyc
ADDED
Binary file (3.37 kB)

app/rag_pipeline/__pycache__/prompt_utils.cpython-310.pyc
ADDED
Binary file (941 Bytes)

app/rag_pipeline/__pycache__/retriever_chain.cpython-310.pyc
ADDED
Binary file (2.75 kB)

app/rag_pipeline/__pycache__/retriver_chain.cpython-310.pyc
ADDED
Binary file (2.87 kB)
app/rag_pipeline/chroma_client.py
ADDED
@@ -0,0 +1,15 @@
from langchain_chroma import Chroma
import logging


logger = logging.getLogger(__name__)

def get_chroma_client(collection_name, embedding_function, persist_directory):
    try:
        logging.info("Setting up chroma client")
        return Chroma(collection_name=collection_name,
                      embedding_function=embedding_function,
                      persist_directory=persist_directory)
    except Exception as e:
        logging.error(f"Failed to initialize Chroma client: {e}")
        raise
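The persisted collection can also be queried directly through this client — a minimal sketch, assuming embeddings were already stored by the /upload route; the query string is a placeholder:

# Direct query against the persisted Chroma collection.
from app.rag_pipeline.chroma_client import get_chroma_client
from app.data_pipeline.embedding_manager import initialize_embedding_model
from app.settings import Config

conf = Config()
db = get_chroma_client(conf.COLLECTION_NAME, initialize_embedding_model(), conf.PERSIST_DIRECTORY)
for doc in db.similarity_search("payment terms", k=3):
    print(doc.metadata, doc.page_content[:80])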
app/rag_pipeline/model_initializer.py
ADDED
@@ -0,0 +1,50 @@
import os
import logging
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings
from app.rag_pipeline.model_loader import load_model
from langchain_huggingface import HuggingFacePipeline
from app.settings import Config
conf = Config()

CACHE_DIR = conf.CACHE_DIR

logger = logging.getLogger(__name__)

os.environ["HUGGINGFACEHUB_API_TOKEN"] = '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'

def initialize_models(openai_api_key=None, model_id=None, model_basename=None):
    """
    Initializes embedding and chat model based on the OpenAI API key availability.

    Returns:
        tuple: (embedding_model, llm_model)
    """

    try:
        if openai_api_key:
            embedding_model = OpenAIEmbeddings(api_key=openai_api_key)
            llm_model = ChatOpenAI(api_key=openai_api_key)
            logger.info("Using OpenAI models.")
        else:
            embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",  # all-mpnet-base-v2
                                                    model_kwargs={'device': 'cpu'},
                                                    encode_kwargs={'normalize_embeddings': False},
                                                    cache_folder=CACHE_DIR
                                                    )
            # llm_model = load_model(device_type="cpu", model_id=model_id, model_basename=model_basename, LOGGING=logger)
            llm_model = HuggingFacePipeline.from_model_id(
                model_id="gpt2",  # "google/flan-t5-small",
                task="text-generation",
            )

            # TheBloke/Mistral-7B-v0.1-GGUF
            # HuggingFaceH4/zephyr-7b-beta

            logger.info("Using Hugging Face embeddings and local LLM model.")

        return embedding_model, llm_model

    except Exception as e:
        logger.error(f"Error initializing models: {e}")
        raise
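The two code paths can be checked with a small sketch; the OpenAI key below is a placeholder, and the no-key branch will download MiniLM and gpt2 from the Hub on first use:

# Sketch only: both branches of initialize_models.
from app.rag_pipeline.model_initializer import initialize_models

emb, llm = initialize_models(openai_api_key="sk-...")      # OpenAIEmbeddings + ChatOpenAI
emb_hf, llm_hf = initialize_models(openai_api_key=None)    # MiniLM embeddings + local gpt2 pipeline
print(type(emb_hf).__name__, type(llm_hf).__name__)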
app/rag_pipeline/model_loader.py
ADDED
@@ -0,0 +1,148 @@
import logging
from auto_gptq import AutoGPTQForCausalLM
from huggingface_hub import hf_hub_download
# from langchain.llms import LlamaCpp
from langchain_community.llms import LlamaCpp
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaTokenizer,
)
from langchain_community.llms import HuggingFacePipeline
from langchain.callbacks.manager import CallbackManager
from transformers import GenerationConfig, pipeline
import torch
import os
from app.settings import Config

conf = Config()


logger = logging.getLogger(__name__)

MODELS_PATH = conf.MODELS_PATH

CONTEXT_WINDOW_SIZE = 2048
MAX_NEW_TOKENS = 2048
N_BATCH = 512
N_GPU_LAYERS = 1

CACHE_DIR = conf.CACHE_DIR  # "./models/"

os.environ["HUGGINGFACEHUB_API_TOKEN"] = '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'

def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):

    try:
        logging.info("Using Llamacpp for GGUF/GGML quantized models")
        model_path = hf_hub_download(
            repo_id=model_id,
            filename=model_basename,
            resume_download=True,
            # force_download=True,
            cache_dir=MODELS_PATH,
        )
        kwargs = {
            "model_path": model_path,
            "n_ctx": CONTEXT_WINDOW_SIZE,
            "max_tokens": MAX_NEW_TOKENS,
            "n_batch": N_BATCH,  # set this based on your GPU & CPU RAM
        }
        if device_type.lower() == "mps":
            kwargs["n_gpu_layers"] = 1
        if device_type.lower() == "cuda":
            kwargs["n_gpu_layers"] = N_GPU_LAYERS  # set this based on your GPU

        return LlamaCpp(**kwargs)
    except:
        if "ggml" in model_basename:
            logging.info("If you were using GGML model, LLAMA-CPP Dropped Support, Use GGUF Instead")
        return None


def load_quantized_model_qptq(model_id, model_basename, device_type, logging):
    logging.info("Using AutoGPTQForCausalLM for quantized models")

    if ".safetensors" in model_basename:
        # Remove the ".safetensors" ending if present
        model_basename = model_basename.replace(".safetensors", "")

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    logging.info("Tokenizer loaded")

    model = AutoGPTQForCausalLM.from_quantized(
        model_id,
        model_basename=model_basename,
        use_safetensors=True,
        trust_remote_code=True,
        device_map="auto",
        use_triton=False,
        quantize_config=None,
    )
    return model, tokenizer

def load_full_model(model_id, model_basename, device_type, logging):

    if device_type.lower() in ["mps", "cpu"]:
        logging.info("Using LlamaTokenizer")
        tokenizer = LlamaTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR, use_auth_token=os.environ["HUGGINGFACEHUB_API_TOKEN"])
        model = LlamaForCausalLM.from_pretrained(model_id, cache_dir=CACHE_DIR, use_auth_token=os.environ["HUGGINGFACEHUB_API_TOKEN"])
    else:
        logging.info("Using AutoModelForCausalLM for full models")
        tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR, use_auth_token=os.environ["HUGGINGFACEHUB_API_TOKEN"])
        logging.info("Tokenizer loaded")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            cache_dir=MODELS_PATH,
            use_auth_token=os.environ["HUGGINGFACEHUB_API_TOKEN"]
            # trust_remote_code=True,  # set these if you are using NVIDIA GPU
            # load_in_4bit=True,
            # bnb_4bit_quant_type="nf4",
            # bnb_4bit_compute_dtype=torch.float16,
            # max_memory={0: "15GB"}  # Uncomment this line if you encounter CUDA out of memory errors
        )
        model.tie_weights()
    return model, tokenizer



def load_model(device_type, model_id, model_basename=None, LOGGING=logger):
    logger.info(f"Loading Model: {model_id}, on: {device_type}")
    logger.info("This action can take a few minutes!")

    if model_basename is not None:
        if ".gguf" in model_basename.lower():
            llm = load_quantized_model_gguf_ggml(
                model_id, model_basename, device_type, LOGGING)
            return llm
        elif ".ggml" in model_basename.lower():
            model, tokenizer = load_quantized_model_gguf_ggml(
                model_id, model_basename, device_type, LOGGING)
        else:
            model, tokenizer = load_quantized_model_qptq(
                model_id, model_basename, device_type, LOGGING)
    else:
        model, tokenizer = load_full_model(
            model_id, model_basename, device_type, LOGGING)

    # Load configuration from the model to avoid warnings
    generation_config = GenerationConfig.from_pretrained(model_id)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=MAX_NEW_TOKENS,
        temperature=0.1,
        # top_p=0.95,
        repetition_penalty=1.15,
        generation_config=generation_config,
    )

    local_llm = HuggingFacePipeline(pipeline=pipe)
    logger.info("Local LLM Loaded")
    return local_llm
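A hedged sketch of calling load_model for a GGUF model on CPU; the repo and file names are the ones mentioned in comments in retriever_chain.py (not verified here), and the call may return None if the llama.cpp load fails:

# Sketch only: load a GGUF-quantized Mistral through llama.cpp on CPU.
from app.rag_pipeline.model_loader import load_model

llm = load_model(
    device_type="cpu",
    model_id="TheBloke/Mistral-7B-v0.1-GGUF",
    model_basename="mistral-7b-v0.1.Q4_0.gguf",
)
if llm is not None:
    print(llm.invoke("Summarize what a vector store is in one sentence."))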
app/rag_pipeline/prompt_utils.py
ADDED
@@ -0,0 +1,31 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


contex_retriever_prompt = ChatPromptTemplate.from_messages([
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "{input}"),
    ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation")
])


conversion_retriever_prompt = ChatPromptTemplate.from_messages([
    ("system",
     "Answer the user's questions based on the below context:\n\n{context}"),
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "{input}"),
])


system_prompt = (
    "You are an assistant specializing in answering questions accurately based on provided context. "
    "Use the context to answer the question concisely. If the answer is not found in the context, respond with 'I'm not sure'."
    "\n\n"
    "Context:\n{context}\n\n"
)

qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("user", "{input}"),
    ]
)
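How qa_prompt resolves its placeholders can be seen with a short sketch; the context and question below are made up:

# Sketch: render qa_prompt with placeholder values.
from app.rag_pipeline.prompt_utils import qa_prompt

messages = qa_prompt.format_messages(
    context="Invoices are payable within 30 days of receipt.",
    input="What is the payment deadline?",
)
for m in messages:
    print(m.type, ":", m.content)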
app/rag_pipeline/retriever_chain.py
ADDED
@@ -0,0 +1,136 @@
import logging
# import create_history_aware_retriever,
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from app.rag_pipeline.prompt_utils import qa_prompt
from app.rag_pipeline.chroma_client import get_chroma_client
from app.settings import Config

# from prompt_utils import qa_prompt
# from chroma_client import get_chroma_client



# import sys
# import os

# parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# sys.path.insert(0, parent_dir)
# from settings import Config


conf = Config()


MODELS_PATH = conf.MODELS_PATH  # '/models'

CONTEXT_WINDOW_SIZE = 2048
MAX_NEW_TOKENS = 2048
N_BATCH = 512
N_GPU_LAYERS = 1

MODEL_ID = conf.MODEL_ID  # "TheBloke/Mistral-7B-v0.1-GGUF"
MODEL_BASENAME = conf.MODEL_BASENAME  # "mistral-7b-v0.1.Q4_0.gguf"
device_type = 'cpu'

logger = logging.getLogger(__name__)

class RetrieverChain:
    def __init__(self, collection_name, embedding_function, persist_directory):
        try:
            self.vector_db = get_chroma_client(collection_name, embedding_function, persist_directory)
        except Exception as e:
            logger.error(f"Error creating RetrieverChain: {e}")
            raise

    def get_retriever(self):
        try:
            retriever = self.vector_db.as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 2})

            return retriever
        except Exception as e:
            logger.error(f"Failed to get retriever: {e}")
            raise

    def get_conversational_rag_chain(self, llm):
        try:

            if self.get_retriever is None:
                logger.error("Retriever must not be None")
                raise ValueError("Retriever must not be None")
            if llm is None:
                logger.error("Model must not be None")
                raise ValueError("Model must not be None")


            question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
            return create_retrieval_chain(self.get_retriever(), question_answer_chain)
        except Exception as e:
            logger.error(f"Error creating RAG chain: {e}")
            raise


    def get_relevent_docs(self, user_input):

        try:
            docs = self.vector_db.as_retriever(search_type="mmr", search_kwargs={"k": 6, "fetch_k": 3}).get_relevant_documents(user_input)
            logger.info(f"Relevent documents for {user_input}: {docs}")
            # Access the retrieved documents

            # print("Relevent Docs")
            # for doc in docs:
            #     print(doc.page_content)  # Access the original text
            #     print(doc.metadata)  # Access any metadata associated with the document
            # print("Relevent Docs end")
            return docs

        except Exception as e:
            logger.error(f"Error getting response: {e}")
            raise

    def get_response(self, user_input, llm):
        try:
            qa_rag_chain = self.get_conversational_rag_chain(llm)
            response = qa_rag_chain.invoke({"input": user_input})
            return response['answer']
        except Exception as e:
            logger.error(f"Error getting response: {e}")
            raise




# if __name__ == "__main__":
#     import os
#     from model_initializer import initialize_models


#     parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))

#     openai_api_key = conf.API_KEY

#     embedding_model, llm_model = initialize_models(openai_api_key, model_id=MODEL_ID, model_basename=MODEL_BASENAME)

#     print(f"embeddi_modelng: {embedding_model}")
#     print(f"llm_model: {llm_model}")

#     collection_name = 'AI_assignment'

#     persist_directory = f'D:/AI Assignment/vector_store'
#     print(f"persist_directory: {persist_directory}")
#     while True:
#         print("Enter query: ")
#         user_query = input()
#         if user_query.lower() == 'exit':
#             break

#         retriever_qa = RetrieverChain(
#             collection_name=collection_name, embedding_function=embedding_model, persist_directory=persist_directory)
#         response = retriever_qa.get_response(user_input=user_query, llm=llm_model)
#         print(f"Response: {response}")
app/settings.py
ADDED
@@ -0,0 +1,33 @@
import os
from dotenv import load_dotenv

from pathlib import Path
env_path = Path(__file__).resolve().parent.parent / '.env'
load_dotenv(dotenv_path=env_path, override=True)

class Config:
    API_KEY = os.getenv('OPENAI_API_KEY')
    MODEL_ID = os.getenv('MODEL_ID')
    MODEL_BASENAME = os.getenv('MODEL_BASENAME')
    COLLECTION_NAME = os.getenv('COLLECTION_NAME')

    PERSIST_DIRECTORY = os.path.join(os.path.dirname(__file__), '..', 'vector_store')
    os.makedirs(PERSIST_DIRECTORY, exist_ok=True)

    UPLOAD_DIR = os.path.join(os.path.dirname(__file__), '..', 'uploads')
    os.makedirs(UPLOAD_DIR, exist_ok=True)

    LOG_DIR = os.path.join(os.path.dirname(__file__), '..', 'log_dir')
    os.makedirs(LOG_DIR, exist_ok=True)

    MODELS_PATH = os.path.join(os.path.dirname(__file__), '..', 'models')

    CACHE_DIR = os.path.join(os.path.dirname(__file__), '..', 'models')
    os.makedirs(CACHE_DIR, exist_ok=True)
    # MODELS_PATH = '/models'

    MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
    MODEL_KWARGS = {'device': 'cpu'}
    ENCODE_KWARGS = {'normalize_embeddings': False}
    CHUNK_SIZE = 1024
    CHUNK_OVERLAP = 200
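Config expects its secrets in a .env file at the repository root (one level above the app package); a minimal sketch that writes placeholder values and reads them back — the model and collection names are only examples taken from comments elsewhere in this commit:

# Hypothetical .env written to the repository root, then read back through Config.
from pathlib import Path

Path(".env").write_text(
    "OPENAI_API_KEY=\n"                          # empty -> Hugging Face fallback models
    "MODEL_ID=TheBloke/Mistral-7B-v0.1-GGUF\n"
    "MODEL_BASENAME=mistral-7b-v0.1.Q4_0.gguf\n"
    "COLLECTION_NAME=AI_assignment\n"
)

from app.settings import Config

conf = Config()
print(conf.MODEL_ID, conf.COLLECTION_NAME, conf.PERSIST_DIRECTORY)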
requirements.txt
ADDED
Binary file (918 Bytes)