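"""FAISS RAG query script.

Loads a prebuilt FAISS index and its LangChain docstore from docs/faiss/, embeds
user queries with Cohere's embed-english-v3.0 model, retrieves the most similar
document chunks, and answers questions with Cohere's command-r-plus chat model,
followed by a formatting pass through DeepSeek's chat API.
"""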
import cohere
import numpy as np
import faiss
import pickle
import os
import traceback # Import traceback for detailed error printing
from dotenv import load_dotenv
from langchain_community.docstore.document import Document
# Corrected import based on the deprecation warning
from langchain_community.docstore.in_memory import InMemoryDocstore
from openai import OpenAI
# Load environment variables
load_dotenv()
cohere_api_key = os.getenv("COHEREAPIKEY")
api_key = os.getenv("DEEPKEY")
# Initialize OpenAI client with minimal configuration
client = OpenAI(
api_key=api_key,
base_url="https://api.deepseek.com/v1"
)
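# DeepSeek exposes an OpenAI-compatible endpoint, so the standard OpenAI client can
# be pointed at its base URL and used unchanged apart from the API key and base_url.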
# Initialize Cohere client
if not cohere_api_key:
    raise ValueError("COHEREAPIKEY not found in environment variables")
co = cohere.Client(cohere_api_key)
# --- Custom Cohere Embeddings Class (for query embedding) ---
class CohereEmbeddingsForQuery:
def __init__(self, client):
self.client = client
self.embed_dim = self._get_embed_dim()
def _get_embed_dim(self):
try:
response = self.client.embed(
texts=["test"], model="embed-english-v3.0", input_type="search_query"
)
return len(response.embeddings[0])
except Exception as e:
print(f"Warning: Could not determine embedding dimension automatically: {e}. Defaulting to 4096.")
return 4096
def embed_query(self, text):
try:
# Ensure text is properly encoded as a string
if not isinstance(text, str):
try:
text = str(text)
except UnicodeEncodeError:
# If there's an encoding error, try to normalize the text
import unicodedata
text = unicodedata.normalize('NFKD', str(text))
response = self.client.embed(
texts=[text],
model="embed-english-v3.0",
input_type="search_query"
)
if hasattr(response, 'embeddings') and len(response.embeddings) > 0:
return np.array(response.embeddings[0]).astype('float32')
else:
print("Warning: No query embedding found in the response. Returning zero vector.")
return np.zeros(self.embed_dim, dtype=np.float32)
except Exception as e:
print(f"Query embedding error: {e}")
return np.zeros(self.embed_dim, dtype=np.float32)
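# Illustrative standalone use of the query embedder (names are from this file; the
# expected vector length is 1024 for embed-english-v3.0):
#   emb = CohereEmbeddingsForQuery(co).embed_query("what does the index cover?")
#   print(emb.shape)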
# --- FAISS Query System ---
class FAISSQuerySystem:
def __init__(self, persist_dir='docs/faiss/'):
self.persist_dir = persist_dir
self.index = None
self.documents = [] # List to hold LangChain Document objects
self.metadata_list = [] # List to hold metadata dictionaries
self.embedding_function = CohereEmbeddingsForQuery(co) # Use the query-specific class
self.load_index()
def stream_chat_completions(self, input_text):
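        """Post-process text with the DeepSeek chat model, adding emojis and light
        formatting. Despite the name, the request is made with stream=False and
        returns the full formatted reply as a single string."""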
# Ensure input_text is properly encoded as a string
if not isinstance(input_text, str):
try:
input_text = str(input_text)
except UnicodeEncodeError:
# If there's an encoding error, try to normalize the text
import unicodedata
input_text = unicodedata.normalize('NFKD', str(input_text))
response = client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": "Your job is to make text more appealing by adding emojis, formatting, and other enhancements. Do not include any awkward markup though."},
{"role": "user", "content": input_text},
],
stream=False
)
try:
resp = response.choices[0].message.content.split("\n---")[1]
        except IndexError:
            # No "\n---" separator in the reply; fall back to the full content
            resp = response.choices[0].message.content
# Extracting just the core content without the extra sections
resp = resp.replace('**', '') # Remove bold formatting
resp = resp.replace('*', '')
return resp
def load_index(self):
"""Load the FAISS index and associated document/metadata files"""
faiss_index_path = os.path.join(self.persist_dir, "index.faiss")
pkl_path = os.path.join(self.persist_dir, "index.pkl")
metadata_path = os.path.join(self.persist_dir, "metadata.pkl")
print(f"Loading FAISS index from: {faiss_index_path}")
print(f"Loading docstore info from: {pkl_path}")
print(f"Loading separate metadata from: {metadata_path}")
if not os.path.exists(faiss_index_path) or not os.path.exists(pkl_path):
raise FileNotFoundError(f"Required index files (index.faiss, index.pkl) not found in {self.persist_dir}")
try:
# 1. Load FAISS index
self.index = faiss.read_index(faiss_index_path)
print(f"FAISS index loaded successfully with {self.index.ntotal} vectors.")
# 2. Load LangChain docstore pickle file
with open(pkl_path, 'rb') as f:
try:
docstore, index_to_docstore_id = pickle.load(f)
except (KeyError, AttributeError) as e:
print(f"Error loading pickle file: {str(e)}")
print("This might be due to a Pydantic version mismatch.")
print("Attempting to recreate the index...")
# Delete the incompatible files
if os.path.exists(faiss_index_path):
os.remove(faiss_index_path)
if os.path.exists(pkl_path):
os.remove(pkl_path)
if os.path.exists(metadata_path):
os.remove(metadata_path)
# Recreate the index
from test import main as recreate_index
recreate_index()
# Try loading again
with open(pkl_path, 'rb') as f:
docstore, index_to_docstore_id = pickle.load(f)
            except UnicodeDecodeError:
                print("Unicode decode error when loading pickle file. Retrying with a latin-1 encoding hint for legacy pickles...")
                # pickle.load needs a binary stream, so pass the encoding hint to
                # pickle.load itself rather than wrapping the file in a text decoder.
                with open(pkl_path, 'rb') as f:
                    docstore, index_to_docstore_id = pickle.load(f, encoding='latin-1')
# Verify the types after loading
print(f"Docstore object loaded. Type: {type(docstore)}")
print(f"Index-to-ID mapping loaded. Type: {type(index_to_docstore_id)}")
# Now this line should work
if isinstance(index_to_docstore_id, dict):
print(f"Mapping contains {len(index_to_docstore_id)} entries.")
else:
# This case should ideally not happen now, but good to have a check
raise TypeError(f"Expected index_to_docstore_id to be a dict, but got {type(index_to_docstore_id)}")
if not isinstance(docstore, InMemoryDocstore):
# Add a check for the docstore type too
print(f"Warning: Expected docstore to be InMemoryDocstore, but got {type(docstore)}")
# 3. Reconstruct the list of documents in FAISS index order
self.documents = []
num_vectors = self.index.ntotal
# Verify consistency
if num_vectors != len(index_to_docstore_id):
print(f"Warning: FAISS index size ({num_vectors}) does not match mapping size ({len(index_to_docstore_id)}). Reconstruction might be incomplete.")
print("Reconstructing document list...")
reconstructed_count = 0
missing_in_mapping = 0
missing_in_docstore = 0
# Ensure docstore has the 'search' method needed.
if not hasattr(docstore, 'search'):
raise AttributeError(f"Loaded docstore object (type: {type(docstore)}) does not have a 'search' method.")
for i in range(num_vectors):
docstore_id = index_to_docstore_id.get(i)
if docstore_id:
# Use the correct method for InMemoryDocstore to retrieve by ID
                doc = docstore.search(docstore_id)
                # InMemoryDocstore.search returns an error string (not None) when the
                # ID is missing, so check for an actual Document instance.
                if isinstance(doc, Document):
                    self.documents.append(doc)
                    reconstructed_count += 1
                else:
                    print(f"Warning: Document with ID '{docstore_id}' (for FAISS index {i}) not found in the loaded docstore.")
                    missing_in_docstore += 1
else:
print(f"Warning: No docstore ID found in mapping for FAISS index {i}.")
missing_in_mapping += 1
print(f"Successfully reconstructed {reconstructed_count} documents.")
if missing_in_mapping > 0: print(f"Could not find mapping for {missing_in_mapping} indices.")
if missing_in_docstore > 0: print(f"Could not find {missing_in_docstore} documents in docstore despite having mapping.")
# 4. Load the separate metadata list
if os.path.exists(metadata_path):
with open(metadata_path, 'rb') as f:
self.metadata_list = pickle.load(f)
print(f"Loaded separate metadata list with {len(self.metadata_list)} entries.")
if len(self.metadata_list) != len(self.documents):
print(f"Warning: Mismatch between reconstructed documents ({len(self.documents)}) and loaded metadata list ({len(self.metadata_list)}).")
print("Falling back to using metadata attached to Document objects if available.")
self.metadata_list = [getattr(doc, 'metadata', {}) for doc in self.documents]
elif not self.documents and self.metadata_list:
print("Warning: Loaded metadata but no documents were reconstructed. Discarding metadata.")
self.metadata_list = []
else:
print("Warning: Separate metadata file (metadata.pkl) not found.")
print("Attempting to use metadata attached to Document objects.")
self.metadata_list = [getattr(doc, 'metadata', {}) for doc in self.documents]
print(f"Final document count: {len(self.documents)}")
print(f"Final metadata count: {len(self.metadata_list)}")
except FileNotFoundError as e:
print(f"Error loading index files: {e}")
raise
except Exception as e:
print(f"An unexpected error occurred during index loading: {e}")
traceback.print_exc()
raise
def search(self, query, k=3):
"""Search the index and return relevant documents with metadata and scores"""
if not self.index or self.index.ntotal == 0:
print("Warning: FAISS index is not loaded or is empty.")
return []
if not self.documents:
print("Warning: No documents were successfully loaded.")
return []
actual_k = min(k, len(self.documents))
if actual_k == 0:
return []
# Ensure query is properly encoded as a string
if not isinstance(query, str):
try:
query = str(query)
except UnicodeEncodeError:
# If there's an encoding error, try to normalize the text
import unicodedata
query = unicodedata.normalize('NFKD', str(query))
query_embedding = self.embedding_function.embed_query(query)
if np.all(query_embedding == 0):
print("Warning: Query embedding failed, search may be ineffective.")
query_embedding_batch = np.array([query_embedding])
distances, indices = self.index.search(query_embedding_batch, actual_k)
results = []
retrieved_indices = indices[0]
for i, idx in enumerate(retrieved_indices):
if idx == -1:
continue
if idx < len(self.documents):
doc = self.documents[idx]
metadata = self.metadata_list[idx] if idx < len(self.metadata_list) else getattr(doc, 'metadata', {})
distance = distances[0][i]
similarity_score = 1.0 / (1.0 + distance) # Basic L2 -> Similarity
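                # Assumes an L2-distance FAISS index (e.g. IndexFlatL2); for an
                # inner-product or cosine index the returned score is already a
                # similarity and this conversion would not be needed.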
# Ensure content is properly encoded as a string
content = getattr(doc, 'page_content', str(doc))
if not isinstance(content, str):
try:
content = str(content)
except UnicodeEncodeError:
# If there's an encoding error, try to normalize the text
import unicodedata
content = unicodedata.normalize('NFKD', str(content))
results.append({
"content": content,
"metadata": metadata,
"score": float(similarity_score)
})
else:
print(f"Warning: Search returned index {idx} which is out of bounds for loaded documents ({len(self.documents)}).")
results.sort(key=lambda x: x['score'], reverse=True)
return results
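    # Illustrative usage (assumes the index files already exist under docs/faiss/):
    #   qs = FAISSQuerySystem()
    #   hits = qs.search("what is covered in the PDFs?", k=3)
    #   for h in hits:
    #       print(f"{h['score']:.3f}", h['metadata'].get('source', 'Unknown'))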
def generate_response(self, query, context_docs):
"""Generate RAG response using Cohere's chat API"""
if not context_docs:
print("No context documents provided to generate_response.")
try:
response = co.chat(
message=f"I could not find relevant documents in my knowledge base to answer your question: '{query}'. Please try rephrasing or asking about topics covered in the source material.",
model="command-r-plus",
temperature=0.3,
preamble="You are an AI assistant explaining limitations."
)
return response.text
except Exception as e:
print(f"Error calling Cohere even without documents: {e}")
return "I could not find relevant documents and encountered an error trying to respond."
formatted_docs = []
# Process documents in batches to reduce memory usage
batch_size = 3
for i in range(0, len(context_docs), batch_size):
batch_end = min(i + batch_size, len(context_docs))
for j in range(i, batch_end):
doc = context_docs[j]
# Ensure content is properly encoded as a string
content = doc['content']
if not isinstance(content, str):
try:
content = str(content)
except UnicodeEncodeError:
# If there's an encoding error, try to normalize the text
import unicodedata
content = unicodedata.normalize('NFKD', str(content))
content_preview = content[:3000]
doc_info = f"Source: {doc['metadata'].get('source', 'Unknown')}\n"
doc_info += f"Type: {doc['metadata'].get('type', 'Unknown')}\n"
doc_info += f"Content Snippet: {content_preview}"
formatted_docs.append({"title": f"Document {j+1} (Source: {doc['metadata'].get('source', 'Unknown')})", "snippet": doc_info})
# Force garbage collection after each batch
import gc
gc.collect()
try:
response = co.chat(
message=query,
documents=formatted_docs,
model="command-r-plus",
temperature=0.3,
prompt_truncation='AUTO',
preamble="You are an expert AI assistant. Answer the user's question based *only* on the provided document snippets. Cite the source document number (e.g., [Document 1]) when using information from it. If the answer isn't in the documents, state that clearly."
)
return self.stream_chat_completions(response.text)
except Exception as e:
print(f"Error during Cohere chat API call: {e}")
traceback.print_exc()
return "Sorry, I encountered an error while trying to generate a response using the retrieved documents."
def main():
try:
# Initialize query system
query_system = FAISSQuerySystem() # Defaults to 'docs/faiss/'
# Interactive query loop
print("\n--- FAISS RAG Query System ---")
print("Ask questions about the content indexed from web, PDFs, and audio.")
print("Type 'exit' or 'quit' to stop.")
while True:
query = input("\nYour question: ")
if query.lower() in ('exit', 'quit'):
print("Exiting...")
break
if not query:
continue
try:
# 1. Search for relevant documents
print("Searching for relevant documents...")
docs = query_system.search(query, k=5) # Get top 5 results
if not docs:
print("Could not find relevant documents in the knowledge base.")
response = query_system.generate_response(query, [])
print("\nResponse:")
print("-" * 50)
print(response)
print("-" * 50)
continue
print(f"Found {len(docs)} relevant document chunks.")
# 2. Generate and display response using RAG
print("Generating response based on documents...")
response = query_system.generate_response(query, docs)
print("\nResponse:")
print("-" * 50)
print(response)
print("-" * 50)
# 3. Show sources (optional)
print("\nRetrieved Sources (Snippets):")
for i, doc in enumerate(docs, 1):
print(f"\n--- Source {i} ---")
print(f" Score: {doc['score']:.4f}")
print(f" Source File: {doc['metadata'].get('source', 'Unknown')}")
print(f" Type: {doc['metadata'].get('type', 'Unknown')}")
if 'page' in doc['metadata']:
print(f" Page (PDF): {doc['metadata']['page']}")
print(f" Content: {doc['content'][:250]}...")
except Exception as e:
print(f"\nAn error occurred while processing your query: {e}")
traceback.print_exc()
except FileNotFoundError as e:
print(f"\nInitialization Error: Could not find necessary index files.")
print(f"Details: {e}")
print("Please ensure you have run the indexing script first and the 'docs/faiss/' directory contains 'index.faiss' and 'index.pkl'.")
except Exception as e:
print(f"\nA critical initialization error occurred: {e}")
traceback.print_exc()
if __name__ == "__main__":
    main()
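# Illustrative run (assumes the indexing script has already written index.faiss,
# index.pkl and metadata.pkl under docs/faiss/, and that COHEREAPIKEY and DEEPKEY
# are set in the environment or a .env file; the filename below is hypothetical):
#   $ python query_faiss.py
#   Your question: What topics does the audio transcript cover?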