# julienserbanescu-rag / queryrun.py
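"""Query-side script for the RAG demo.

Loads a previously built FAISS index from docs/faiss/, embeds user questions
with Cohere (embed-english-v3.0), answers them with Cohere chat
(command-r-plus) grounded on the retrieved chunks, and finally runs the
answer through DeepSeek for light formatting. Run the indexing script first
so that index.faiss, index.pkl and metadata.pkl exist.
"""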
import cohere
import numpy as np
import faiss
import pickle
import os
import traceback # Import traceback for detailed error printing
from dotenv import load_dotenv
from langchain_community.docstore.document import Document
# Corrected import based on the deprecation warning
from langchain_community.docstore.in_memory import InMemoryDocstore
from openai import OpenAI

# Load environment variables
load_dotenv()
cohere_api_key = os.getenv("COHEREAPIKEY")
api_key = os.getenv("DEEPKEY")

# Initialize the OpenAI-compatible client pointed at the DeepSeek API
client = OpenAI(
    api_key=api_key,
    base_url="https://api.deepseek.com/v1"
)

# Initialize Cohere client
if not cohere_api_key:
    raise ValueError("COHEREAPIKEY not found in environment variables")
co = cohere.Client(cohere_api_key)
# --- Custom Cohere Embeddings Class (for query embedding) ---
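# Note: Cohere v3 embedding models distinguish query and document embeddings
# via input_type. This wrapper always uses input_type="search_query"; it
# assumes the indexed chunks were embedded with input_type="search_document"
# by the indexing script.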
class CohereEmbeddingsForQuery:
    """Minimal helper that wraps the Cohere client for embedding single query strings."""

    def __init__(self, client):
        self.client = client
        self.embed_dim = self._get_embed_dim()

    def _get_embed_dim(self):
        """Embed a dummy string once to discover the embedding dimension."""
        try:
            response = self.client.embed(
                texts=["test"], model="embed-english-v3.0", input_type="search_query"
            )
            return len(response.embeddings[0])
        except Exception as e:
            # embed-english-v3.0 vectors are 1024-dimensional, so fall back to that.
            print(f"Warning: Could not determine embedding dimension automatically: {e}. Defaulting to 1024.")
            return 1024

    def embed_query(self, text):
        try:
            # Ensure text is properly encoded as a string
            if not isinstance(text, str):
                try:
                    text = str(text)
                except UnicodeEncodeError:
                    # If there's an encoding error, try to normalize the text
                    import unicodedata
                    text = unicodedata.normalize('NFKD', str(text))
            response = self.client.embed(
                texts=[text],
                model="embed-english-v3.0",
                input_type="search_query"
            )
            if hasattr(response, 'embeddings') and len(response.embeddings) > 0:
                return np.array(response.embeddings[0]).astype('float32')
            else:
                print("Warning: No query embedding found in the response. Returning zero vector.")
                return np.zeros(self.embed_dim, dtype=np.float32)
        except Exception as e:
            print(f"Query embedding error: {e}")
            return np.zeros(self.embed_dim, dtype=np.float32)

# --- FAISS Query System ---
class FAISSQuerySystem:
    def __init__(self, persist_dir='docs/faiss/'):
        self.persist_dir = persist_dir
        self.index = None
        self.documents = []       # List to hold LangChain Document objects
        self.metadata_list = []   # List to hold metadata dictionaries
        self.embedding_function = CohereEmbeddingsForQuery(co)  # Use the query-specific class
        self.load_index()

    def stream_chat_completions(self, input_text):
        """Reformat a finished answer with DeepSeek (a single non-streaming call, despite the name)."""
        # Ensure input_text is properly encoded as a string
        if not isinstance(input_text, str):
            try:
                input_text = str(input_text)
            except UnicodeEncodeError:
                # If there's an encoding error, try to normalize the text
                import unicodedata
                input_text = unicodedata.normalize('NFKD', str(input_text))
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "Your job is to make text more appealing by adding emojis, formatting, and other enhancements. Do not include any awkward markup though."},
                {"role": "user", "content": input_text},
            ],
            stream=False
        )
        # Extract just the core content after the first "\n---" separator, if one exists
        try:
            resp = response.choices[0].message.content.split("\n---")[1]
        except IndexError:
            resp = response.choices[0].message.content
        resp = resp.replace('**', '')  # Remove bold formatting
        resp = resp.replace('*', '')
        return resp
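
    # The files loaded below appear to follow LangChain's FAISS.save_local layout:
    # index.faiss holds the raw vectors, index.pkl pickles a
    # (docstore, index_to_docstore_id) pair, and metadata.pkl is an extra,
    # project-specific list of per-chunk metadata written by the indexing script.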
    def load_index(self):
        """Load the FAISS index and associated document/metadata files."""
        faiss_index_path = os.path.join(self.persist_dir, "index.faiss")
        pkl_path = os.path.join(self.persist_dir, "index.pkl")
        metadata_path = os.path.join(self.persist_dir, "metadata.pkl")
        print(f"Loading FAISS index from: {faiss_index_path}")
        print(f"Loading docstore info from: {pkl_path}")
        print(f"Loading separate metadata from: {metadata_path}")

        if not os.path.exists(faiss_index_path) or not os.path.exists(pkl_path):
            raise FileNotFoundError(f"Required index files (index.faiss, index.pkl) not found in {self.persist_dir}")

        try:
            # 1. Load FAISS index
            self.index = faiss.read_index(faiss_index_path)
            print(f"FAISS index loaded successfully with {self.index.ntotal} vectors.")

            # 2. Load the LangChain docstore pickle file: a (docstore, index_to_docstore_id) pair
            with open(pkl_path, 'rb') as f:
                try:
                    docstore, index_to_docstore_id = pickle.load(f)
                except (KeyError, AttributeError) as e:
                    print(f"Error loading pickle file: {str(e)}")
                    print("This might be due to a Pydantic version mismatch.")
                    print("Attempting to recreate the index...")
                    # Delete the incompatible files
                    if os.path.exists(faiss_index_path):
                        os.remove(faiss_index_path)
                    if os.path.exists(pkl_path):
                        os.remove(pkl_path)
                    if os.path.exists(metadata_path):
                        os.remove(metadata_path)
                    # Recreate the index with the indexing script
                    from test import main as recreate_index
                    recreate_index()
                    # Try loading again
                    with open(pkl_path, 'rb') as f_retry:
                        docstore, index_to_docstore_id = pickle.load(f_retry)
                except UnicodeDecodeError:
                    print("Unicode decode error when loading pickle file. Retrying with a latin-1 fallback...")
                    # Re-read the file and let pickle decode legacy byte strings as latin-1
                    with open(pkl_path, 'rb') as f_retry:
                        docstore, index_to_docstore_id = pickle.load(f_retry, encoding='latin-1')

            # Verify the types after loading
            print(f"Docstore object loaded. Type: {type(docstore)}")
            print(f"Index-to-ID mapping loaded. Type: {type(index_to_docstore_id)}")
            if isinstance(index_to_docstore_id, dict):
                print(f"Mapping contains {len(index_to_docstore_id)} entries.")
            else:
                # This case should not happen, but it is a useful sanity check
                raise TypeError(f"Expected index_to_docstore_id to be a dict, but got {type(index_to_docstore_id)}")
            if not isinstance(docstore, InMemoryDocstore):
                # Also check the docstore type
                print(f"Warning: Expected docstore to be InMemoryDocstore, but got {type(docstore)}")

            # 3. Reconstruct the list of documents in FAISS index order
            self.documents = []
            num_vectors = self.index.ntotal
            # Verify consistency
            if num_vectors != len(index_to_docstore_id):
                print(f"Warning: FAISS index size ({num_vectors}) does not match mapping size ({len(index_to_docstore_id)}). Reconstruction might be incomplete.")
            print("Reconstructing document list...")
            reconstructed_count = 0
            missing_in_mapping = 0
            missing_in_docstore = 0
            # Ensure the docstore has the 'search' method needed for lookups by ID
            if not hasattr(docstore, 'search'):
                raise AttributeError(f"Loaded docstore object (type: {type(docstore)}) does not have a 'search' method.")
            for i in range(num_vectors):
                docstore_id = index_to_docstore_id.get(i)
                if docstore_id:
                    # Use the correct method for InMemoryDocstore to retrieve by ID
                    doc = docstore.search(docstore_id)
                    if doc:
                        self.documents.append(doc)
                        reconstructed_count += 1
                    else:
                        print(f"Warning: Document with ID '{docstore_id}' (for FAISS index {i}) not found in the loaded docstore.")
                        missing_in_docstore += 1
                else:
                    print(f"Warning: No docstore ID found in mapping for FAISS index {i}.")
                    missing_in_mapping += 1
            print(f"Successfully reconstructed {reconstructed_count} documents.")
            if missing_in_mapping > 0:
                print(f"Could not find mapping for {missing_in_mapping} indices.")
            if missing_in_docstore > 0:
                print(f"Could not find {missing_in_docstore} documents in docstore despite having mapping.")

            # 4. Load the separate metadata list
            if os.path.exists(metadata_path):
                with open(metadata_path, 'rb') as f:
                    self.metadata_list = pickle.load(f)
                print(f"Loaded separate metadata list with {len(self.metadata_list)} entries.")
                if len(self.metadata_list) != len(self.documents):
                    print(f"Warning: Mismatch between reconstructed documents ({len(self.documents)}) and loaded metadata list ({len(self.metadata_list)}).")
                    print("Falling back to using metadata attached to Document objects if available.")
                    self.metadata_list = [getattr(doc, 'metadata', {}) for doc in self.documents]
                elif not self.documents and self.metadata_list:
                    print("Warning: Loaded metadata but no documents were reconstructed. Discarding metadata.")
                    self.metadata_list = []
            else:
                print("Warning: Separate metadata file (metadata.pkl) not found.")
                print("Attempting to use metadata attached to Document objects.")
                self.metadata_list = [getattr(doc, 'metadata', {}) for doc in self.documents]

            print(f"Final document count: {len(self.documents)}")
            print(f"Final metadata count: {len(self.metadata_list)}")
        except FileNotFoundError as e:
            print(f"Error loading index files: {e}")
            raise
        except Exception as e:
            print(f"An unexpected error occurred during index loading: {e}")
            traceback.print_exc()
            raise
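
    # Scoring convention: FAISS returns raw distances (L2, judging by the
    # conversion used below); they are mapped to a similarity in (0, 1] via
    # 1 / (1 + distance), so larger scores mean closer matches.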
    def search(self, query, k=3):
        """Search the index and return relevant documents with metadata and scores."""
        if not self.index or self.index.ntotal == 0:
            print("Warning: FAISS index is not loaded or is empty.")
            return []
        if not self.documents:
            print("Warning: No documents were successfully loaded.")
            return []
        actual_k = min(k, len(self.documents))
        if actual_k == 0:
            return []

        # Ensure query is properly encoded as a string
        if not isinstance(query, str):
            try:
                query = str(query)
            except UnicodeEncodeError:
                # If there's an encoding error, try to normalize the text
                import unicodedata
                query = unicodedata.normalize('NFKD', str(query))

        query_embedding = self.embedding_function.embed_query(query)
        if np.all(query_embedding == 0):
            print("Warning: Query embedding failed, search may be ineffective.")

        query_embedding_batch = np.array([query_embedding])
        distances, indices = self.index.search(query_embedding_batch, actual_k)

        results = []
        retrieved_indices = indices[0]
        for i, idx in enumerate(retrieved_indices):
            if idx == -1:
                continue
            if idx < len(self.documents):
                doc = self.documents[idx]
                metadata = self.metadata_list[idx] if idx < len(self.metadata_list) else getattr(doc, 'metadata', {})
                distance = distances[0][i]
                similarity_score = 1.0 / (1.0 + distance)  # Basic L2 distance -> similarity
                # Ensure content is properly encoded as a string
                content = getattr(doc, 'page_content', str(doc))
                if not isinstance(content, str):
                    try:
                        content = str(content)
                    except UnicodeEncodeError:
                        # If there's an encoding error, try to normalize the text
                        import unicodedata
                        content = unicodedata.normalize('NFKD', str(content))
                results.append({
                    "content": content,
                    "metadata": metadata,
                    "score": float(similarity_score)
                })
            else:
                print(f"Warning: Search returned index {idx} which is out of bounds for loaded documents ({len(self.documents)}).")

        results.sort(key=lambda x: x['score'], reverse=True)
        return results
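
    # Generation happens in two stages: Cohere's chat endpoint produces a
    # grounded answer from the retrieved snippets (passed via `documents`),
    # and stream_chat_completions() then asks DeepSeek to polish the wording.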
    def generate_response(self, query, context_docs):
        """Generate RAG response using Cohere's chat API."""
        if not context_docs:
            print("No context documents provided to generate_response.")
            try:
                response = co.chat(
                    message=f"I could not find relevant documents in my knowledge base to answer your question: '{query}'. Please try rephrasing or asking about topics covered in the source material.",
                    model="command-r-plus",
                    temperature=0.3,
                    preamble="You are an AI assistant explaining limitations."
                )
                return response.text
            except Exception as e:
                print(f"Error calling Cohere even without documents: {e}")
                return "I could not find relevant documents and encountered an error trying to respond."

        formatted_docs = []
        # Process documents in batches to reduce memory usage
        batch_size = 3
        for i in range(0, len(context_docs), batch_size):
            batch_end = min(i + batch_size, len(context_docs))
            for j in range(i, batch_end):
                doc = context_docs[j]
                # Ensure content is properly encoded as a string
                content = doc['content']
                if not isinstance(content, str):
                    try:
                        content = str(content)
                    except UnicodeEncodeError:
                        # If there's an encoding error, try to normalize the text
                        import unicodedata
                        content = unicodedata.normalize('NFKD', str(content))
                content_preview = content[:3000]
                doc_info = f"Source: {doc['metadata'].get('source', 'Unknown')}\n"
                doc_info += f"Type: {doc['metadata'].get('type', 'Unknown')}\n"
                doc_info += f"Content Snippet: {content_preview}"
                formatted_docs.append({"title": f"Document {j+1} (Source: {doc['metadata'].get('source', 'Unknown')})", "snippet": doc_info})
            # Force garbage collection after each batch
            import gc
            gc.collect()

        try:
            response = co.chat(
                message=query,
                documents=formatted_docs,
                model="command-r-plus",
                temperature=0.3,
                prompt_truncation='AUTO',
                preamble="You are an expert AI assistant. Answer the user's question based *only* on the provided document snippets. Cite the source document number (e.g., [Document 1]) when using information from it. If the answer isn't in the documents, state that clearly."
            )
            return self.stream_chat_completions(response.text)
        except Exception as e:
            print(f"Error during Cohere chat API call: {e}")
            traceback.print_exc()
            return "Sorry, I encountered an error while trying to generate a response using the retrieved documents."
def main():
    try:
        # Initialize query system
        query_system = FAISSQuerySystem()  # Defaults to 'docs/faiss/'

        # Interactive query loop
        print("\n--- FAISS RAG Query System ---")
        print("Ask questions about the content indexed from web, PDFs, and audio.")
        print("Type 'exit' or 'quit' to stop.")
        while True:
            query = input("\nYour question: ")
            if query.lower() in ('exit', 'quit'):
                print("Exiting...")
                break
            if not query:
                continue
            try:
                # 1. Search for relevant documents
                print("Searching for relevant documents...")
                docs = query_system.search(query, k=5)  # Get top 5 results
                if not docs:
                    print("Could not find relevant documents in the knowledge base.")
                    response = query_system.generate_response(query, [])
                    print("\nResponse:")
                    print("-" * 50)
                    print(response)
                    print("-" * 50)
                    continue
                print(f"Found {len(docs)} relevant document chunks.")

                # 2. Generate and display response using RAG
                print("Generating response based on documents...")
                response = query_system.generate_response(query, docs)
                print("\nResponse:")
                print("-" * 50)
                print(response)
                print("-" * 50)

                # 3. Show sources (optional)
                print("\nRetrieved Sources (Snippets):")
                for i, doc in enumerate(docs, 1):
                    print(f"\n--- Source {i} ---")
                    print(f" Score: {doc['score']:.4f}")
                    print(f" Source File: {doc['metadata'].get('source', 'Unknown')}")
                    print(f" Type: {doc['metadata'].get('type', 'Unknown')}")
                    if 'page' in doc['metadata']:
                        print(f" Page (PDF): {doc['metadata']['page']}")
                    print(f" Content: {doc['content'][:250]}...")
            except Exception as e:
                print(f"\nAn error occurred while processing your query: {e}")
                traceback.print_exc()
    except FileNotFoundError as e:
        print("\nInitialization Error: Could not find necessary index files.")
        print(f"Details: {e}")
        print("Please ensure you have run the indexing script first and the 'docs/faiss/' directory contains 'index.faiss' and 'index.pkl'.")
    except Exception as e:
        print(f"\nA critical initialization error occurred: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()