"""
Retriever module for Norwegian RAG chatbot.
Retrieves relevant document chunks based on query embeddings.
"""
import os
import json
import numpy as np
from typing import List, Dict, Any, Optional
from ..api.huggingface_api import HuggingFaceAPI
from ..api.config import MAX_CHUNKS_TO_RETRIEVE, SIMILARITY_THRESHOLD
class Retriever:
"""
Retrieves relevant document chunks based on query embeddings.
Uses cosine similarity to find the most relevant chunks.
"""
def __init__(
self,
api_client: Optional[HuggingFaceAPI] = None,
processed_dir: str = "/home/ubuntu/chatbot_project/data/processed",
max_chunks: int = MAX_CHUNKS_TO_RETRIEVE,
similarity_threshold: float = SIMILARITY_THRESHOLD
):
"""
Initialize the retriever.
Args:
api_client: HuggingFaceAPI client for generating embeddings
processed_dir: Directory containing processed documents
max_chunks: Maximum number of chunks to retrieve
similarity_threshold: Minimum similarity score for retrieval
"""
self.api_client = api_client or HuggingFaceAPI()
self.processed_dir = processed_dir
self.max_chunks = max_chunks
self.similarity_threshold = similarity_threshold
# Load document index
self.document_index_path = os.path.join(self.processed_dir, "document_index.json")
self.document_index = self._load_document_index()
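        # The index is assumed (based on _load_document_index's docstring) to
        # map document IDs to metadata dicts, e.g.:
        #   {"doc_001": {"title": "...", "source": "..."}}
        # The metadata keys shown here are illustrative, not guaranteed.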
def retrieve(self, query: str) -> List[Dict[str, Any]]:
"""
Retrieve relevant document chunks for a query.
Args:
query: User query
Returns:
List of retrieved chunks with metadata
"""
# Generate embedding for the query
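        # generate_embeddings is assumed to return a batch (list) of vectors
        # even for a single input string, hence taking element [0] below.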
query_embedding = self.api_client.generate_embeddings(query)[0]
# Find relevant chunks across all documents
all_results = []
for doc_id in self.document_index:
try:
# Load document data
doc_results = self._retrieve_from_document(doc_id, query_embedding)
all_results.extend(doc_results)
            except Exception as e:
                print(f"Error retrieving from document {doc_id}: {e}")
# Sort all results by similarity score
all_results.sort(key=lambda x: x["similarity"], reverse=True)
# Return top results above threshold
return [
result for result in all_results[:self.max_chunks]
if result["similarity"] >= self.similarity_threshold
]
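    # Illustration of the filtering above (scores are hypothetical): with
    # max_chunks=3 and similarity_threshold=0.75, candidate scores
    # [0.91, 0.82, 0.70, 0.64] are first truncated to the top three and then
    # filtered by the threshold, so only the 0.91 and 0.82 chunks are returned.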
def _retrieve_from_document(
self,
document_id: str,
query_embedding: List[float]
) -> List[Dict[str, Any]]:
"""
Retrieve relevant chunks from a specific document.
Args:
document_id: Document ID
query_embedding: Query embedding vector
Returns:
List of retrieved chunks with metadata
"""
document_path = os.path.join(self.processed_dir, f"{document_id}.json")
if not os.path.exists(document_path):
return []
# Load document data
with open(document_path, 'r', encoding='utf-8') as f:
document_data = json.load(f)
chunks = document_data.get("chunks", [])
embeddings = document_data.get("embeddings", [])
metadata = document_data.get("metadata", {})
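        # Skip documents whose chunk and embedding lists are missing or misaligned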
if not chunks or not embeddings or len(chunks) != len(embeddings):
return []
# Calculate similarity scores
results = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
similarity = self._cosine_similarity(query_embedding, embedding)
results.append({
"document_id": document_id,
"chunk_index": i,
"chunk_text": chunk,
"similarity": similarity,
"metadata": metadata
})
# Sort by similarity
results.sort(key=lambda x: x["similarity"], reverse=True)
return results[:self.max_chunks]
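    # Assumed on-disk layout for "<document_id>.json", inferred from the reads
    # in _retrieve_from_document; the field names match this file, the values
    # are illustrative only:
    #   {
    #     "chunks": ["first chunk text", "second chunk text"],
    #     "embeddings": [[0.12, -0.03], [0.08, 0.41]],
    #     "metadata": {"title": "...", "source": "..."}
    #   }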
def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
"""
Calculate cosine similarity between two vectors.
Args:
vec1: First vector
vec2: Second vector
Returns:
Cosine similarity score
"""
        v1 = np.asarray(vec1, dtype=float)
        v2 = np.asarray(vec2, dtype=float)
        norm1 = np.linalg.norm(v1)
        norm2 = np.linalg.norm(v2)
        # Guard against zero-length vectors to avoid division by zero
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(np.dot(v1, v2) / (norm1 * norm2))
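    # Worked example (for illustration): vec1 = [1, 0] and vec2 = [1, 1] give
    # a dot product of 1 and norms of 1 and sqrt(2), so the similarity is
    # 1/sqrt(2) ≈ 0.7071. Parallel vectors score 1.0; orthogonal ones 0.0.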
def _load_document_index(self) -> Dict[str, Dict[str, Any]]:
"""
Load the document index from disk.
Returns:
Dictionary of document IDs to metadata
"""
if os.path.exists(self.document_index_path):
try:
with open(self.document_index_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
print(f"Error loading document index: {str(e)}")
return {}
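# Minimal usage sketch, not part of the original module: it shows the intended
# call pattern under the assumption that HuggingFaceAPI is configured with
# valid credentials and that processed documents exist on disk. The query
# string and directory below are hypothetical.
if __name__ == "__main__":
    retriever = Retriever(processed_dir="data/processed")  # hypothetical path
    for result in retriever.retrieve("Hva er reglene for foreldrepermisjon?"):
        print(
            f"[{result['similarity']:.3f}] {result['document_id']} "
            f"(chunk {result['chunk_index']}): {result['chunk_text'][:80]}"
        )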