Spaces:
Sleeping
Sleeping
import cohere | |
import numpy as np | |
import faiss | |
import pickle | |
import os | |
import traceback # Import traceback for detailed error printing | |
from dotenv import load_dotenv | |
from langchain_community.docstore.document import Document | |
# Corrected import based on the deprecation warning | |
from langchain_community.docstore.in_memory import InMemoryDocstore | |
from openai import OpenAI | |
# Load environment variables (python-dotenv reads a local .env file if present).
load_dotenv()
cohere_api_key = os.getenv("COHEREAPIKEY")
api_key = os.getenv("DEEPKEY")

# Initialize the OpenAI client against the DeepSeek endpoint; the key comes
# from the DEEPKEY environment variable.
client = OpenAI(
    api_key=api_key,
    base_url="https://api.deepseek.com/v1"
)

# Initialize Cohere client. The error message names the variable that is
# actually read above, so the user knows exactly what to set.
if not cohere_api_key:
    raise ValueError("COHEREAPIKEY not found in environment variables")
co = cohere.Client(cohere_api_key)
# --- Custom Cohere Embeddings Class (for query embedding) --- | |
class CohereEmbeddingsForQuery:
    """Embed search queries with a Cohere client.

    Args:
        client: Object exposing ``embed(texts=..., model=..., input_type=...)``
            and returning a response with an ``embeddings`` sequence.
        model: Name of the embedding model to request.
        fallback_dim: Vector size to assume when the dimension cannot be
            probed. The default of 1024 is the output size of
            ``embed-english-v3.0``.

    Attributes:
        embed_dim: Dimension of the vectors returned by ``embed_query``.
    """

    def __init__(self, client, model="embed-english-v3.0", fallback_dim=1024):
        self.client = client
        self.model = model
        self.fallback_dim = fallback_dim
        self.embed_dim = self._get_embed_dim()

    def _get_embed_dim(self):
        """Probe the model once to learn its embedding dimension.

        Returns:
            The length of a probe embedding, or ``fallback_dim`` when the
            probe request fails for any reason.
        """
        try:
            response = self.client.embed(
                texts=["test"], model=self.model, input_type="search_query"
            )
            return len(response.embeddings[0])
        except Exception as e:
            print(
                "Warning: Could not determine embedding dimension automatically: "
                f"{e}. Defaulting to {self.fallback_dim}."
            )
            return self.fallback_dim

    def embed_query(self, text):
        """Return the embedding of ``text`` as a 1-D float32 numpy array.

        Non-string input is converted with ``str()``. This method never
        raises: on any failure it logs the problem and returns a zero vector
        of length ``embed_dim`` so callers can detect the failure.
        """
        try:
            if not isinstance(text, str):
                text = str(text)
            response = self.client.embed(
                texts=[text],
                model=self.model,
                input_type="search_query"
            )
            embeddings = getattr(response, 'embeddings', None)
            if embeddings is not None and len(embeddings) > 0:
                return np.array(embeddings[0]).astype('float32')
            print("Warning: No query embedding found in the response. Returning zero vector.")
            return np.zeros(self.embed_dim, dtype=np.float32)
        except Exception as e:
            print(f"Query embedding error: {e}")
            return np.zeros(self.embed_dim, dtype=np.float32)
# --- FAISS Query System --- | |
class FAISSQuerySystem:
    """Question answering over a FAISS index persisted on disk.

    The constructor loads ``index.faiss``, ``index.pkl`` and (optionally)
    ``metadata.pkl`` from ``persist_dir``. ``search`` retrieves the nearest
    documents for a query and ``generate_response`` asks Cohere to answer
    using those documents.

    Attributes:
        persist_dir: Directory holding the index files.
        index: The loaded FAISS index, or ``None`` before loading.
        documents: Documents recovered from the docstore, in FAISS order,
            with unrecoverable entries omitted.
        metadata_list: One metadata dict per entry of ``documents``.
    """

    def __init__(self, persist_dir='docs/faiss/'):
        self.persist_dir = persist_dir
        self.index = None
        self.documents = []  # List to hold LangChain Document objects
        self.metadata_list = []  # List to hold metadata dictionaries
        # FAISS vector id -> position in self.documents. Needed because
        # vectors whose document cannot be recovered are skipped, which would
        # otherwise shift every later document onto the wrong vector id.
        self._index_to_position = {}
        self.embedding_function = CohereEmbeddingsForQuery(co)  # Use the query-specific class
        self.load_index()

    @staticmethod
    def _as_text(value):
        """Return ``value`` unchanged if it is a ``str``, else ``str(value)``."""
        return value if isinstance(value, str) else str(value)

    def stream_chat_completions(self, input_text):
        """Ask the DeepSeek chat model to restyle ``input_text``.

        Despite the name, the request is made with ``stream=False`` and the
        complete text is returned at once.

        Args:
            input_text: Text to restyle; non-strings are converted with ``str``.

        Returns:
            The model output with every ``*`` removed. If the output contains
            a ``"\\n---"`` separator, only the section following the first
            separator (up to the next one) is kept.

        Raises:
            Exception: Whatever the chat client raises is propagated.
        """
        input_text = self._as_text(input_text)
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "Your job is to make text more appealing by adding emojis, formatting, and other enhancements. Do not include any awkward markup though."},
                {"role": "user", "content": input_text},
            ],
            stream=False
        )
        # The content may be None (e.g. an empty completion); treat it as "".
        content = response.choices[0].message.content or ""
        sections = content.split("\n---")
        resp = sections[1] if len(sections) > 1 else content
        # Removing every '*' also removes '**' bold markers.
        return resp.replace('*', '')

    @staticmethod
    def _read_docstore_pickle(pkl_path):
        """Unpickle and return the object stored at ``pkl_path``.

        SECURITY: unpickling can execute arbitrary code; only load index
        files that were produced by a trusted indexing run.

        A ``UnicodeDecodeError`` triggers one retry with
        ``encoding='latin1'``, which decodes any byte string.
        """
        with open(pkl_path, 'rb') as f:
            try:
                return pickle.load(f)
            except UnicodeDecodeError:
                print("Unicode decode error when loading pickle file. Attempting to handle special characters...")
                f.seek(0)
                return pickle.load(f, encoding='latin1')

    def _reconstruct_documents(self, docstore, index_to_docstore_id):
        """Fill ``self.documents`` from the docstore in FAISS id order."""
        self.documents = []
        self._index_to_position = {}
        num_vectors = self.index.ntotal
        # Verify consistency
        if num_vectors != len(index_to_docstore_id):
            print(f"Warning: FAISS index size ({num_vectors}) does not match mapping size ({len(index_to_docstore_id)}). Reconstruction might be incomplete.")
        print("Reconstructing document list...")
        missing_in_mapping = 0
        missing_in_docstore = 0
        # Ensure docstore has the 'search' method needed.
        if not hasattr(docstore, 'search'):
            raise AttributeError(f"Loaded docstore object (type: {type(docstore)}) does not have a 'search' method.")
        for i in range(num_vectors):
            docstore_id = index_to_docstore_id.get(i)
            if not docstore_id:
                print(f"Warning: No docstore ID found in mapping for FAISS index {i}.")
                missing_in_mapping += 1
                continue
            doc = docstore.search(docstore_id)
            # InMemoryDocstore.search reports a miss by returning a message
            # string rather than None, so a str result is not a document.
            if not doc or isinstance(doc, str):
                print(f"Warning: Document with ID '{docstore_id}' (for FAISS index {i}) not found in the loaded docstore.")
                missing_in_docstore += 1
                continue
            self._index_to_position[i] = len(self.documents)
            self.documents.append(doc)
        print(f"Successfully reconstructed {len(self.documents)} documents.")
        if missing_in_mapping > 0:
            print(f"Could not find mapping for {missing_in_mapping} indices.")
        if missing_in_docstore > 0:
            print(f"Could not find {missing_in_docstore} documents in docstore despite having mapping.")

    def _load_metadata(self, metadata_path):
        """Fill ``self.metadata_list`` from ``metadata.pkl`` or the documents."""
        if not os.path.exists(metadata_path):
            print("Warning: Separate metadata file (metadata.pkl) not found.")
            print("Attempting to use metadata attached to Document objects.")
            self.metadata_list = [getattr(doc, 'metadata', {}) for doc in self.documents]
            return
        # SECURITY: same trust requirement as the docstore pickle.
        with open(metadata_path, 'rb') as f:
            self.metadata_list = pickle.load(f)
        print(f"Loaded separate metadata list with {len(self.metadata_list)} entries.")
        if len(self.metadata_list) != len(self.documents):
            print(f"Warning: Mismatch between reconstructed documents ({len(self.documents)}) and loaded metadata list ({len(self.metadata_list)}).")
            print("Falling back to using metadata attached to Document objects if available.")
            self.metadata_list = [getattr(doc, 'metadata', {}) for doc in self.documents]
        elif not self.documents and self.metadata_list:
            print("Warning: Loaded metadata but no documents were reconstructed. Discarding metadata.")
            self.metadata_list = []

    def load_index(self):
        """Load the FAISS index and associated document/metadata files.

        If the docstore pickle fails to load with ``KeyError`` or
        ``AttributeError``, the index files are deleted, regenerated by
        calling ``main`` from the ``test`` module, and both the FAISS index
        and the pickle are loaded again.

        Raises:
            FileNotFoundError: ``index.faiss`` or ``index.pkl`` is missing.
            TypeError: The id mapping in the pickle is not a dict.
            AttributeError: The docstore has no ``search`` method.
            Exception: Any other loading error, after it has been logged.
        """
        faiss_index_path = os.path.join(self.persist_dir, "index.faiss")
        pkl_path = os.path.join(self.persist_dir, "index.pkl")
        metadata_path = os.path.join(self.persist_dir, "metadata.pkl")
        print(f"Loading FAISS index from: {faiss_index_path}")
        print(f"Loading docstore info from: {pkl_path}")
        print(f"Loading separate metadata from: {metadata_path}")
        if not os.path.exists(faiss_index_path) or not os.path.exists(pkl_path):
            raise FileNotFoundError(f"Required index files (index.faiss, index.pkl) not found in {self.persist_dir}")
        try:
            # 1. Load FAISS index
            self.index = faiss.read_index(faiss_index_path)
            print(f"FAISS index loaded successfully with {self.index.ntotal} vectors.")
            # 2. Load LangChain docstore pickle file
            try:
                docstore, index_to_docstore_id = self._read_docstore_pickle(pkl_path)
            except (KeyError, AttributeError) as e:
                print(f"Error loading pickle file: {str(e)}")
                print("This might be due to a Pydantic version mismatch.")
                print("Attempting to recreate the index...")
                # The pickle handle is closed by now, so the files can be
                # removed on every platform.
                for stale_path in (faiss_index_path, pkl_path, metadata_path):
                    if os.path.exists(stale_path):
                        os.remove(stale_path)
                # NOTE(review): presumably a project-local test.py that
                # rebuilds the index; verify it is not shadowed by the
                # standard-library 'test' package.
                from test import main as recreate_index
                recreate_index()
                # Reload the FAISS index too: the one read above belongs to
                # the files that were just deleted.
                self.index = faiss.read_index(faiss_index_path)
                docstore, index_to_docstore_id = self._read_docstore_pickle(pkl_path)
            # Verify the types after loading
            print(f"Docstore object loaded. Type: {type(docstore)}")
            print(f"Index-to-ID mapping loaded. Type: {type(index_to_docstore_id)}")
            if not isinstance(index_to_docstore_id, dict):
                raise TypeError(f"Expected index_to_docstore_id to be a dict, but got {type(index_to_docstore_id)}")
            print(f"Mapping contains {len(index_to_docstore_id)} entries.")
            if not isinstance(docstore, InMemoryDocstore):
                print(f"Warning: Expected docstore to be InMemoryDocstore, but got {type(docstore)}")
            # 3. Reconstruct the list of documents in FAISS index order
            self._reconstruct_documents(docstore, index_to_docstore_id)
            # 4. Load the separate metadata list
            self._load_metadata(metadata_path)
            print(f"Final document count: {len(self.documents)}")
            print(f"Final metadata count: {len(self.metadata_list)}")
        except FileNotFoundError as e:
            print(f"Error loading index files: {e}")
            raise
        except Exception as e:
            print(f"An unexpected error occurred during index loading: {e}")
            traceback.print_exc()
            raise

    def search(self, query, k=3):
        """Search the index and return relevant documents with metadata and scores.

        Args:
            query: The question; non-strings are converted with ``str``.
            k: Maximum number of results.

        Returns:
            A list of ``{"content", "metadata", "score"}`` dicts sorted by
            descending score, where ``score = 1 / (1 + distance)``. The list
            is empty when nothing is loaded, ``k`` is not positive, or the
            query vector does not fit the index.
        """
        if self.index is None or self.index.ntotal == 0:
            print("Warning: FAISS index is not loaded or is empty.")
            return []
        if not self.documents:
            print("Warning: No documents were successfully loaded.")
            return []
        actual_k = min(k, len(self.documents))
        if actual_k <= 0:
            return []
        query = self._as_text(query)
        query_embedding = np.asarray(self.embedding_function.embed_query(query), dtype='float32')
        if np.all(query_embedding == 0):
            print("Warning: Query embedding failed, search may be ineffective.")
        # A vector of the wrong size cannot be searched; bail out instead of
        # letting the index fail on it.
        index_dim = getattr(self.index, 'd', None)
        if index_dim is not None and query_embedding.shape[-1] != index_dim:
            print(f"Warning: Query embedding dimension ({query_embedding.shape[-1]}) does not match index dimension ({index_dim}).")
            return []
        distances, indices = self.index.search(query_embedding.reshape(1, -1), actual_k)
        # Without a populated map (documents assigned directly), vector ids
        # are taken to be list positions.
        position_of = getattr(self, '_index_to_position', None)
        results = []
        for rank, idx in enumerate(indices[0]):
            idx = int(idx)
            if idx == -1:  # FAISS pads missing neighbours with -1
                continue
            position = position_of.get(idx) if position_of else idx
            if position is None or not 0 <= position < len(self.documents):
                print(f"Warning: Search returned index {idx} which is out of bounds for loaded documents ({len(self.documents)}).")
                continue
            doc = self.documents[position]
            if position < len(self.metadata_list):
                metadata = self.metadata_list[position]
            else:
                metadata = getattr(doc, 'metadata', {})
            similarity_score = 1.0 / (1.0 + distances[0][rank])  # Basic L2 -> Similarity
            results.append({
                "content": self._as_text(getattr(doc, 'page_content', str(doc))),
                "metadata": metadata,
                "score": float(similarity_score)
            })
        results.sort(key=lambda x: x['score'], reverse=True)
        return results

    def _format_context_docs(self, context_docs):
        """Convert search results into Cohere ``documents`` entries."""
        formatted_docs = []
        for number, doc in enumerate(context_docs, 1):
            metadata = doc.get('metadata') or {}
            source = metadata.get('source', 'Unknown')
            # Cap each snippet at 3000 characters to bound the prompt size.
            content_preview = self._as_text(doc['content'])[:3000]
            doc_info = f"Source: {source}\n"
            doc_info += f"Type: {metadata.get('type', 'Unknown')}\n"
            doc_info += f"Content Snippet: {content_preview}"
            formatted_docs.append({"title": f"Document {number} (Source: {source})", "snippet": doc_info})
        return formatted_docs

    def generate_response(self, query, context_docs):
        """Generate RAG response using Cohere's chat API.

        Args:
            query: The user's question.
            context_docs: Result dicts as returned by ``search``.

        Returns:
            The answer text. With context documents, the Cohere answer is
            restyled by ``stream_chat_completions``; if that step fails, the
            unstyled Cohere answer is returned. Errors from Cohere yield a
            fixed apology string, so this method does not raise.
        """
        if not context_docs:
            print("No context documents provided to generate_response.")
            try:
                response = co.chat(
                    message=f"I could not find relevant documents in my knowledge base to answer your question: '{query}'. Please try rephrasing or asking about topics covered in the source material.",
                    model="command-r-plus",
                    temperature=0.3,
                    preamble="You are an AI assistant explaining limitations."
                )
                return response.text
            except Exception as e:
                print(f"Error calling Cohere even without documents: {e}")
                return "I could not find relevant documents and encountered an error trying to respond."
        try:
            response = co.chat(
                message=query,
                documents=self._format_context_docs(context_docs),
                model="command-r-plus",
                temperature=0.3,
                prompt_truncation='AUTO',
                preamble="You are an expert AI assistant. Answer the user's question based *only* on the provided document snippets. Cite the source document number (e.g., [Document 1]) when using information from it. If the answer isn't in the documents, state that clearly."
            )
            answer = response.text
        except Exception as e:
            print(f"Error during Cohere chat API call: {e}")
            traceback.print_exc()
            return "Sorry, I encountered an error while trying to generate a response using the retrieved documents."
        try:
            return self.stream_chat_completions(answer)
        except Exception as e:
            # Restyling is cosmetic; keep the valid answer if it fails.
            print(f"Warning: Could not post-process the response, returning it unformatted: {e}")
            return answer
def main():
    """Run the interactive command-line question/answer loop.

    Loads the index from the default directory, then repeatedly reads a
    question, retrieves up to five documents, prints the generated answer
    and lists the retrieved sources. The loop ends on ``exit``/``quit``,
    end of input, or Ctrl-C. Errors while answering a single question are
    reported and the loop continues.
    """
    try:
        # Initialize query system
        query_system = FAISSQuerySystem()  # Defaults to 'docs/faiss/'
    except FileNotFoundError as e:
        print("\nInitialization Error: Could not find necessary index files.")
        print(f"Details: {e}")
        print("Please ensure you have run the indexing script first and the 'docs/faiss/' directory contains 'index.faiss' and 'index.pkl'.")
        return
    except Exception as e:
        print(f"\nA critical initialization error occurred: {e}")
        traceback.print_exc()
        return

    # Interactive query loop
    print("\n--- FAISS RAG Query System ---")
    print("Ask questions about the content indexed from web, PDFs, and audio.")
    print("Type 'exit' or 'quit' to stop.")
    while True:
        try:
            query = input("\nYour question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C is a normal way to leave, not an error.
            print("\nExiting...")
            break
        if query.lower() in ('exit', 'quit'):
            print("Exiting...")
            break
        if not query:
            continue
        try:
            # 1. Search for relevant documents
            print("Searching for relevant documents...")
            docs = query_system.search(query, k=5)  # Get top 5 results
            if docs:
                print(f"Found {len(docs)} relevant document chunks.")
                print("Generating response based on documents...")
            else:
                print("Could not find relevant documents in the knowledge base.")
            # 2. Generate and display response (an empty list makes
            # generate_response explain that nothing was found)
            response = query_system.generate_response(query, docs)
            print("\nResponse:")
            print("-" * 50)
            print(response)
            print("-" * 50)
            if not docs:
                continue
            # 3. Show sources (optional)
            print("\nRetrieved Sources (Snippets):")
            for i, doc in enumerate(docs, 1):
                metadata = doc['metadata']
                print(f"\n--- Source {i} ---")
                print(f" Score: {doc['score']:.4f}")
                print(f" Source File: {metadata.get('source', 'Unknown')}")
                print(f" Type: {metadata.get('type', 'Unknown')}")
                if 'page' in metadata:
                    print(f" Page (PDF): {metadata['page']}")
                print(f" Content: {doc['content'][:250]}...")
        except Exception as e:
            print(f"\nAn error occurred while processing your query: {e}")
            traceback.print_exc()


if __name__ == "__main__":
    main()