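"""Interactive RAG query tool.

Loads a persisted FAISS index (built separately over web pages, PDFs, and
audio transcripts), embeds user questions with Cohere's embed-english-v3.0,
answers them with command-r-plus grounded in the retrieved chunks, and
post-processes the answer with DeepSeek for presentation.
"""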
import cohere
import numpy as np
import faiss
import pickle
import os
import gc
import unicodedata
import traceback  # For detailed error reporting
from dotenv import load_dotenv
from langchain_community.docstore.document import Document
# Corrected import based on the deprecation warning
from langchain_community.docstore.in_memory import InMemoryDocstore
from openai import OpenAI

# Load environment variables
load_dotenv()
cohere_api_key = os.getenv("COHERE_API_KEY")
api_key = os.getenv("DEEPKEY")

# Initialize the DeepSeek client (OpenAI-compatible API) with minimal configuration
if not api_key:
    raise ValueError("DEEPKEY not found in environment variables")
client = OpenAI(
    api_key=api_key,
    base_url="https://api.deepseek.com/v1"
)

# Initialize Cohere client
if not cohere_api_key:
    raise ValueError("COHERE_API_KEY not found in environment variables")
co = cohere.Client(cohere_api_key)


def ensure_str(text):
    """Coerce input to str, normalizing characters that fail to convert.

    Shared by the embedding, chat, and search paths, which all accept
    possibly non-string input.
    """
    if isinstance(text, str):
        return text
    try:
        return str(text)
    except UnicodeEncodeError:
        # If there's an encoding error, try to normalize the text
        return unicodedata.normalize('NFKD', str(text))
# --- Custom Cohere Embeddings Class (for query embedding) ---
class CohereEmbeddingsForQuery:
    def __init__(self, client):
        self.client = client
        self.embed_dim = self._get_embed_dim()

    def _get_embed_dim(self):
        try:
            response = self.client.embed(
                texts=["test"], model="embed-english-v3.0", input_type="search_query"
            )
            return len(response.embeddings[0])
        except Exception as e:
            # embed-english-v3.0 returns 1024-dimensional vectors
            print(f"Warning: Could not determine embedding dimension automatically: {e}. Defaulting to 1024.")
            return 1024

    def embed_query(self, text):
        try:
            # Coerce the query to a clean string before embedding
            text = ensure_str(text)
            response = self.client.embed(
                texts=[text],
                model="embed-english-v3.0",
                input_type="search_query"
            )
            if hasattr(response, 'embeddings') and len(response.embeddings) > 0:
                return np.array(response.embeddings[0]).astype('float32')
            else:
                print("Warning: No query embedding found in the response. Returning zero vector.")
                return np.zeros(self.embed_dim, dtype=np.float32)
        except Exception as e:
            print(f"Query embedding error: {e}")
            return np.zeros(self.embed_dim, dtype=np.float32)
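
# Illustrative usage (assumes COHERE_API_KEY is set): embed_query returns a
# float32 vector sized to the model (1024 dims for embed-english-v3.0), or a
# zero vector of the same size if the API call fails.
#
#   embedder = CohereEmbeddingsForQuery(co)
#   vec = embedder.embed_query("What topics do the PDFs cover?")
#   print(vec.shape)  # -> (1024,)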
# --- FAISS Query System ---
class FAISSQuerySystem:
    def __init__(self, persist_dir='docs/faiss/'):
        self.persist_dir = persist_dir
        self.index = None
        self.documents = []      # LangChain Document objects, in FAISS index order
        self.metadata_list = []  # Metadata dictionaries, parallel to self.documents
        self.embedding_function = CohereEmbeddingsForQuery(co)  # Query-specific embedder
        self.load_index()
    def stream_chat_completions(self, input_text):
        """Polish a response with DeepSeek (a single non-streaming call, despite the name)."""
        # Coerce input to a clean string
        input_text = ensure_str(input_text)
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "Your job is to make text more appealing by adding emojis, formatting, and other enhancements. Do not include any awkward markup though."},
                {"role": "user", "content": input_text},
            ],
            stream=False
        )
        try:
            # Keep only the core content when the model wraps it in "---" sections
            resp = response.choices[0].message.content.split("\n---")[1]
        except IndexError:
            resp = response.choices[0].message.content
        resp = resp.replace('**', '')  # Remove bold formatting
        resp = resp.replace('*', '')   # Remove italic formatting
        return resp
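
    # Illustrative call (assumes DEEPKEY points at a valid DeepSeek account):
    #
    #   system = FAISSQuerySystem()
    #   print(system.stream_chat_completions("FAISS retrieval found 3 chunks."))
    #
    # The reply comes back in one piece with Markdown emphasis stripped.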
    def load_index(self):
        """Load the FAISS index and associated document/metadata files."""
        faiss_index_path = os.path.join(self.persist_dir, "index.faiss")
        pkl_path = os.path.join(self.persist_dir, "index.pkl")
        metadata_path = os.path.join(self.persist_dir, "metadata.pkl")
        print(f"Loading FAISS index from: {faiss_index_path}")
        print(f"Loading docstore info from: {pkl_path}")
        print(f"Loading separate metadata from: {metadata_path}")
        if not os.path.exists(faiss_index_path) or not os.path.exists(pkl_path):
            raise FileNotFoundError(f"Required index files (index.faiss, index.pkl) not found in {self.persist_dir}")
        try:
            # 1. Load the FAISS index
            self.index = faiss.read_index(faiss_index_path)
            print(f"FAISS index loaded successfully with {self.index.ntotal} vectors.")

            # 2. Load the LangChain docstore pickle file
            with open(pkl_path, 'rb') as f:
                try:
                    docstore, index_to_docstore_id = pickle.load(f)
                except (KeyError, AttributeError) as e:
                    print(f"Error loading pickle file: {e}")
                    print("This might be due to a Pydantic version mismatch.")
                    print("Attempting to recreate the index...")
                    # Delete the incompatible files
                    if os.path.exists(faiss_index_path):
                        os.remove(faiss_index_path)
                    if os.path.exists(pkl_path):
                        os.remove(pkl_path)
                    if os.path.exists(metadata_path):
                        os.remove(metadata_path)
                    # Recreate the index with the companion indexing script
                    from test import main as recreate_index
                    recreate_index()
                    # Try loading again
                    with open(pkl_path, 'rb') as f_retry:
                        docstore, index_to_docstore_id = pickle.load(f_retry)
                except UnicodeDecodeError:
                    print("Unicode decode error when loading pickle file. Retrying with latin-1 decoding...")
                    # Pickle data is binary; rewind and let pickle decode legacy
                    # string fields instead of re-reading the file as UTF-8 text
                    f.seek(0)
                    docstore, index_to_docstore_id = pickle.load(f, encoding='latin-1')

            # Verify the types after loading
            print(f"Docstore object loaded. Type: {type(docstore)}")
            print(f"Index-to-ID mapping loaded. Type: {type(index_to_docstore_id)}")
            if isinstance(index_to_docstore_id, dict):
                print(f"Mapping contains {len(index_to_docstore_id)} entries.")
            else:
                # This should not happen, but fail loudly if it does
                raise TypeError(f"Expected index_to_docstore_id to be a dict, but got {type(index_to_docstore_id)}")
            if not isinstance(docstore, InMemoryDocstore):
                print(f"Warning: Expected docstore to be InMemoryDocstore, but got {type(docstore)}")

            # 3. Reconstruct the list of documents in FAISS index order
            self.documents = []
            num_vectors = self.index.ntotal
            # Verify consistency
            if num_vectors != len(index_to_docstore_id):
                print(f"Warning: FAISS index size ({num_vectors}) does not match mapping size ({len(index_to_docstore_id)}). Reconstruction might be incomplete.")
            print("Reconstructing document list...")
            reconstructed_count = 0
            missing_in_mapping = 0
            missing_in_docstore = 0
            # Ensure the docstore has the 'search' method we rely on
            if not hasattr(docstore, 'search'):
                raise AttributeError(f"Loaded docstore object (type: {type(docstore)}) does not have a 'search' method.")
            for i in range(num_vectors):
                docstore_id = index_to_docstore_id.get(i)
                if docstore_id:
                    # InMemoryDocstore retrieves a document by ID via search()
                    doc = docstore.search(docstore_id)
                    if doc:
                        self.documents.append(doc)
                        reconstructed_count += 1
                    else:
                        print(f"Warning: Document with ID '{docstore_id}' (for FAISS index {i}) not found in the loaded docstore.")
                        missing_in_docstore += 1
                else:
                    print(f"Warning: No docstore ID found in mapping for FAISS index {i}.")
                    missing_in_mapping += 1
            print(f"Successfully reconstructed {reconstructed_count} documents.")
            if missing_in_mapping > 0:
                print(f"Could not find mapping for {missing_in_mapping} indices.")
            if missing_in_docstore > 0:
                print(f"Could not find {missing_in_docstore} documents in the docstore despite having a mapping.")

            # 4. Load the separate metadata list
            if os.path.exists(metadata_path):
                with open(metadata_path, 'rb') as f:
                    self.metadata_list = pickle.load(f)
                print(f"Loaded separate metadata list with {len(self.metadata_list)} entries.")
                if len(self.metadata_list) != len(self.documents):
                    print(f"Warning: Mismatch between reconstructed documents ({len(self.documents)}) and loaded metadata list ({len(self.metadata_list)}).")
                    print("Falling back to using metadata attached to Document objects if available.")
                    self.metadata_list = [getattr(doc, 'metadata', {}) for doc in self.documents]
                elif not self.documents and self.metadata_list:
                    print("Warning: Loaded metadata but no documents were reconstructed. Discarding metadata.")
                    self.metadata_list = []
            else:
                print("Warning: Separate metadata file (metadata.pkl) not found.")
                print("Attempting to use metadata attached to Document objects.")
                self.metadata_list = [getattr(doc, 'metadata', {}) for doc in self.documents]
            print(f"Final document count: {len(self.documents)}")
            print(f"Final metadata count: {len(self.metadata_list)}")
        except FileNotFoundError as e:
            print(f"Error loading index files: {e}")
            raise
        except Exception as e:
            print(f"An unexpected error occurred during index loading: {e}")
            traceback.print_exc()
            raise
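
    # Expected layout of persist_dir, produced by the separate indexing script
    # (paths as used above; the exact builder is outside this file):
    #
    #   docs/faiss/index.faiss   - raw FAISS vector index
    #   docs/faiss/index.pkl     - pickled (InMemoryDocstore, index -> docstore-id dict)
    #   docs/faiss/metadata.pkl  - optional pickled list of per-chunk metadata dicts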
    def search(self, query, k=3):
        """Search the index and return relevant documents with metadata and scores."""
        if not self.index or self.index.ntotal == 0:
            print("Warning: FAISS index is not loaded or is empty.")
            return []
        if not self.documents:
            print("Warning: No documents were successfully loaded.")
            return []
        actual_k = min(k, len(self.documents))
        if actual_k == 0:
            return []
        # Coerce the query to a clean string
        query = ensure_str(query)
        query_embedding = self.embedding_function.embed_query(query)
        if np.all(query_embedding == 0):
            print("Warning: Query embedding failed, search may be ineffective.")
        # FAISS expects a 2-D batch of float32 vectors
        query_embedding_batch = np.array([query_embedding])
        distances, indices = self.index.search(query_embedding_batch, actual_k)
        results = []
        retrieved_indices = indices[0]
        for i, idx in enumerate(retrieved_indices):
            if idx == -1:  # FAISS pads with -1 when fewer than k hits exist
                continue
            if idx < len(self.documents):
                doc = self.documents[idx]
                metadata = self.metadata_list[idx] if idx < len(self.metadata_list) else getattr(doc, 'metadata', {})
                distance = distances[0][i]
                similarity_score = 1.0 / (1.0 + distance)  # Basic L2 distance -> similarity
                # Coerce content to a clean string
                content = ensure_str(getattr(doc, 'page_content', str(doc)))
                results.append({
                    "content": content,
                    "metadata": metadata,
                    "score": float(similarity_score)
                })
            else:
                print(f"Warning: Search returned index {idx}, which is out of bounds for the loaded documents ({len(self.documents)}).")
        results.sort(key=lambda x: x['score'], reverse=True)
        return results
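
    # Note: the 1/(1 + d) mapping sends an exact match (d = 0) to 1.0 and decays
    # toward 0. If the index were instead built over L2-normalized embeddings
    # with an inner-product index (faiss.IndexFlatIP), the raw search scores
    # would already be cosine similarities, e.g.:
    #
    #   faiss.normalize_L2(query_embedding_batch)      # in-place, float32 2-D
    #   sims, indices = self.index.search(query_embedding_batch, actual_k)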
    def generate_response(self, query, context_docs):
        """Generate a RAG response using Cohere's chat API."""
        if not context_docs:
            print("No context documents provided to generate_response.")
            try:
                response = co.chat(
                    message=f"I could not find relevant documents in my knowledge base to answer your question: '{query}'. Please try rephrasing or asking about topics covered in the source material.",
                    model="command-r-plus",
                    temperature=0.3,
                    preamble="You are an AI assistant explaining limitations."
                )
                return response.text
            except Exception as e:
                print(f"Error calling Cohere even without documents: {e}")
                return "I could not find relevant documents and encountered an error trying to respond."
        formatted_docs = []
        # Process documents in batches to reduce memory usage
        batch_size = 3
        for i in range(0, len(context_docs), batch_size):
            batch_end = min(i + batch_size, len(context_docs))
            for j in range(i, batch_end):
                doc = context_docs[j]
                # Coerce content to a clean string and truncate long snippets
                content = ensure_str(doc['content'])
                content_preview = content[:3000]
                doc_info = f"Source: {doc['metadata'].get('source', 'Unknown')}\n"
                doc_info += f"Type: {doc['metadata'].get('type', 'Unknown')}\n"
                doc_info += f"Content Snippet: {content_preview}"
                formatted_docs.append({"title": f"Document {j+1} (Source: {doc['metadata'].get('source', 'Unknown')})", "snippet": doc_info})
            # Force garbage collection after each batch
            gc.collect()
        try:
            response = co.chat(
                message=query,
                documents=formatted_docs,
                model="command-r-plus",
                temperature=0.3,
                prompt_truncation='AUTO',
                preamble="You are an expert AI assistant. Answer the user's question based *only* on the provided document snippets. Cite the source document number (e.g., [Document 1]) when using information from it. If the answer isn't in the documents, state that clearly."
            )
            return self.stream_chat_completions(response.text)
        except Exception as e:
            print(f"Error during Cohere chat API call: {e}")
            traceback.print_exc()
            return "Sorry, I encountered an error while trying to generate a response using the retrieved documents."
def main():
    try:
        # Initialize the query system (defaults to 'docs/faiss/')
        query_system = FAISSQuerySystem()
        # Interactive query loop
        print("\n--- FAISS RAG Query System ---")
        print("Ask questions about the content indexed from web, PDFs, and audio.")
        print("Type 'exit' or 'quit' to stop.")
        while True:
            query = input("\nYour question: ")
            if query.lower() in ('exit', 'quit'):
                print("Exiting...")
                break
            if not query:
                continue
            try:
                # 1. Search for relevant documents
                print("Searching for relevant documents...")
                docs = query_system.search(query, k=5)  # Get the top 5 results
                if not docs:
                    print("Could not find relevant documents in the knowledge base.")
                    response = query_system.generate_response(query, [])
                    print("\nResponse:")
                    print("-" * 50)
                    print(response)
                    print("-" * 50)
                    continue
                print(f"Found {len(docs)} relevant document chunks.")
                # 2. Generate and display the response using RAG
                print("Generating response based on documents...")
                response = query_system.generate_response(query, docs)
                print("\nResponse:")
                print("-" * 50)
                print(response)
                print("-" * 50)
                # 3. Show sources (optional)
                print("\nRetrieved Sources (Snippets):")
                for i, doc in enumerate(docs, 1):
                    print(f"\n--- Source {i} ---")
                    print(f"  Score: {doc['score']:.4f}")
                    print(f"  Source File: {doc['metadata'].get('source', 'Unknown')}")
                    print(f"  Type: {doc['metadata'].get('type', 'Unknown')}")
                    if 'page' in doc['metadata']:
                        print(f"  Page (PDF): {doc['metadata']['page']}")
                    print(f"  Content: {doc['content'][:250]}...")
            except Exception as e:
                print(f"\nAn error occurred while processing your query: {e}")
                traceback.print_exc()
    except FileNotFoundError as e:
        print("\nInitialization Error: Could not find necessary index files.")
        print(f"Details: {e}")
        print("Please ensure you have run the indexing script first and that the 'docs/faiss/' directory contains 'index.faiss' and 'index.pkl'.")
    except Exception as e:
        print(f"\nA critical initialization error occurred: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()