SoumyaJ committed (verified)
Commit 09eb7a4 · Parent: 77851e6

Update app.py

Files changed (1): app.py (+17 -40)
app.py CHANGED
@@ -2,26 +2,24 @@ from fastapi import FastAPI, UploadFile,File,HTTPException
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from dotenv import load_dotenv
-from langchain_community.document_loaders import PyMuPDFLoader
+from langchain_community.document_loaders import PyMuPDFLoader, UnstructuredPDFLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain_groq import ChatGroq
-from langchain_pinecone import PineconeVectorStore
 from langchain_core.runnables import RunnablePassthrough
 from pathlib import Path
 import uvicorn
 import shutil
 import os
 import hashlib
-from pinecone import Pinecone
 import fitz
 import pytesseract
 from PIL import Image
 from langchain.schema import Document
+from langchain_community.vectorstores import Chroma
 import io
-import time
 
 app = FastAPI()
 
@@ -36,21 +34,17 @@ app.add_middleware(
 UPLOAD_DIR = "uploads"
 os.makedirs(UPLOAD_DIR, exist_ok=True)
 
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-
-index_name = "pinecone-chatbot"
+persist_directory = "./chroma_db"
 
 load_dotenv()
 os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
-os.environ["PINECONE_API_KEY"] = os.getenv("PINECONE_API_KEY")
 os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
 
-llm = ChatGroq(model_name = "qwen-2.5-32b")
+llm = ChatGroq(model_name="Llama3-8b-8192")
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 prompt = '''You are given a context below. Use it to answer the question that follows.
-Provide a concise and factual response based on the context as below". If user mentions keywords such as "file","pdf", "document", please refer them as context.
-If you cannot find the answer, please reply *the answer cannot be found in the given context*
+Provide a concise and factual response. If the answer is not in the context, simply state "I don't know based on the context provided."
 
 <context>
 {context}
@@ -61,9 +55,6 @@ Answer:'''
 
 parser = StrOutputParser()
 
-pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
-index = pc.Index(name=index_name)
-
 def generate_file_id(file_path):
     hasher = hashlib.md5()
     with open(file_path, "rb") as f:
@@ -71,16 +62,15 @@ def generate_file_id(file_path):
     return hasher.hexdigest()
 
 def delete_existing_embedding(file_id):
-    index_stats = index.describe_index_stats()
-    if index_stats["total_vector_count"] > 0:
-        index.delete(delete_all=True)
+    if os.path.exists(persist_directory):
+        shutil.rmtree(persist_directory)
 
 def tempUploadFile(filePath,file):
     with open(filePath,'wb') as buffer:
         shutil.copyfileobj(file.file, buffer)
 
 def loadAndSplitDocuments(filePath):
-    loader = PyMuPDFLoader(filePath)
+    loader = UnstructuredPDFLoader(filePath)
    docs = loader.load()
 
     splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
@@ -118,30 +108,16 @@ def loadAndSplitPdfFile(filePath):
     final_chunks = splitter.split_documents(documents)
     return final_chunks
 
-def prepare_retriever(filePath = "", load_from_pinecone = False):
-    if load_from_pinecone:
-        vector_store = PineconeVectorStore.from_existing_index(index_name, embeddings)
+def prepare_retriever(filePath = "", load_from_chromadb = False):
+    if load_from_chromadb:
+        vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
         return vector_store.as_retriever(search_kwargs={"k": 5})
     elif filePath:
-        doc_chunks = loadAndSplitPdfFile(filePath)
-        vector_data = []
+        doc_chunks = loadAndSplitPdfFile(filePath)
+        print(f"Loaded {len(doc_chunks)} documents from {filePath}")
 
-        for i, doc in enumerate(doc_chunks):
-            embedding = embeddings.embed_query(doc.page_content)
-            if embedding:
-                metadata = {
-                    "text": doc.page_content,
-                    "source": str(doc.metadata.get("source", "unknown")),
-                    "page": int(doc.metadata.get("page", i)), # Add page info if available
-                }
-                vector_data.append((str(i), embedding, metadata))
-        print(f"Upserting {len(vector_data)} records into Pinecone...")
-
-        index.describe_index_stats()
-        time.sleep(2)
-
-        index.upsert(vectors=vector_data)
-        print("Upsert complete")
+        vector_store = Chroma.from_documents(documents=doc_chunks, embedding=embeddings, persist_directory=persist_directory)
+        vector_store.persist()
 
 def get_retriever_chain(retriever):
     chat_prompt = ChatPromptTemplate.from_template(prompt)
@@ -156,6 +132,7 @@ def UploadFileInStore(file: UploadFile = File(...)):
     filePath = Path(UPLOAD_DIR) / file.filename
     tempUploadFile(filePath,file)
     file_id = generate_file_id(filePath)
+
     delete_existing_embedding(file_id)
     prepare_retriever(filePath)
 
@@ -166,7 +143,7 @@
 
 @app.get("/QnAFromPdf")
 async def QnAFromPdf(query: str):
-    retriever = prepare_retriever(load_from_pinecone=True)
+    retriever = prepare_retriever(load_from_chromadb=True)
     chain = get_retriever_chain(retriever)
     response = chain.invoke(query)
     return response
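For reference, a minimal sketch of the Chroma round-trip that the new prepare_retriever relies on. One easy-to-miss asymmetry in the langchain_community wrapper: Chroma.from_documents() takes the keyword `embedding`, while the Chroma() constructor takes `embedding_function`.

# Minimal Chroma round-trip sketch (assumes langchain-community and chromadb are installed).
from langchain.schema import Document
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
chunks = [Document(page_content="Chroma persists vectors to a local directory.")]

# Build: embed the chunks and write them under ./chroma_db
store = Chroma.from_documents(documents=chunks, embedding=embeddings,
                              persist_directory="./chroma_db")

# Reload in a later request (as the /QnAFromPdf handler does) without re-embedding
store = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
docs = store.as_retriever(search_kwargs={"k": 5}).invoke("Where are vectors stored?")
print(docs[0].page_content)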
 
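The body of loadAndSplitPdfFile is not shown in this diff, but the fitz, pytesseract, and PIL imports suggest it falls back to OCR for scanned pages. A hypothetical sketch of that pattern (ocr_pdf_pages and the dpi choice are illustrative, not the app's actual code):

import io
import fitz  # PyMuPDF
import pytesseract
from PIL import Image
from langchain.schema import Document

def ocr_pdf_pages(file_path):
    # Read each page's native text layer; OCR the rendered page image when it is empty.
    docs = []
    with fitz.open(file_path) as pdf:
        for page_num, page in enumerate(pdf):
            text = page.get_text()
            if not text.strip():  # likely a scanned page with no text layer
                pix = page.get_pixmap(dpi=300)
                image = Image.open(io.BytesIO(pix.tobytes("png")))
                text = pytesseract.image_to_string(image)
            docs.append(Document(page_content=text,
                                 metadata={"source": file_path, "page": page_num}))
    return docs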
 
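A quick way to exercise the two endpoints once the app is running (assumes a local uvicorn on port 8000 and that the upload route is registered as a POST; neither detail is visible in this diff):

import requests

BASE = "http://localhost:8000"

# Upload a PDF so its chunks are embedded into ./chroma_db
with open("sample.pdf", "rb") as f:
    r = requests.post(f"{BASE}/UploadFileInStore",
                      files={"file": ("sample.pdf", f, "application/pdf")})
print(r.status_code, r.text)

# Ask a question against the persisted index
r = requests.get(f"{BASE}/QnAFromPdf", params={"query": "What is this document about?"})
print(r.status_code, r.text)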