"""
Retriever module for Norwegian RAG chatbot.

Retrieves relevant document chunks based on query embeddings.
"""

import json
import os
from typing import Any, Dict, List, Optional

import numpy as np

from ..api.config import MAX_CHUNKS_TO_RETRIEVE, SIMILARITY_THRESHOLD
from ..api.huggingface_api import HuggingFaceAPI


class Retriever:
    """
    Retrieves relevant document chunks based on query embeddings.

    Uses cosine similarity to find the most relevant chunks.
    """

    def __init__(
        self,
        api_client: Optional[HuggingFaceAPI] = None,
        processed_dir: str = "/home/ubuntu/chatbot_project/data/processed",
        max_chunks: int = MAX_CHUNKS_TO_RETRIEVE,
        similarity_threshold: float = SIMILARITY_THRESHOLD
    ):
        """
        Initialize the retriever.

        Args:
            api_client: HuggingFaceAPI client for generating embeddings
            processed_dir: Directory containing processed documents
            max_chunks: Maximum number of chunks to retrieve
            similarity_threshold: Minimum similarity score for retrieval
        """
        self.api_client = api_client or HuggingFaceAPI()
        self.processed_dir = processed_dir
        self.max_chunks = max_chunks
        self.similarity_threshold = similarity_threshold

        # Load the persisted document index (document ID -> metadata).
        self.document_index_path = os.path.join(self.processed_dir, "document_index.json")
        self.document_index = self._load_document_index()

    def retrieve(self, query: str) -> List[Dict[str, Any]]:
        """
        Retrieve relevant document chunks for a query.

        Args:
            query: User query

        Returns:
            List of retrieved chunks with metadata
        """
        # Embed the query. The API client returns a list of embeddings, one
        # per input text, so take the first (and only) one.
        query_embedding = self.api_client.generate_embeddings(query)[0]

        # Score chunks from every indexed document; a failure in one document
        # should not abort retrieval from the others.
        all_results = []
        for doc_id in self.document_index:
            try:
                doc_results = self._retrieve_from_document(doc_id, query_embedding)
                all_results.extend(doc_results)
            except Exception as e:
                print(f"Error retrieving from document {doc_id}: {str(e)}")

        # Sort by similarity (highest first), then keep only the top matches
        # that clear the similarity threshold.
        all_results.sort(key=lambda x: x["similarity"], reverse=True)
        return [
            result for result in all_results[:self.max_chunks]
            if result["similarity"] >= self.similarity_threshold
        ]
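
    # Illustrative shape of one retrieved result, as built in
    # _retrieve_from_document below (the concrete values are made up,
    # not real data):
    #
    #     {
    #         "document_id": "doc_001",
    #         "chunk_index": 3,
    #         "chunk_text": "Oslo er hovedstaden i Norge ...",
    #         "similarity": 0.87,
    #         "metadata": {"title": "..."},
    #     }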

    def _retrieve_from_document(
        self,
        document_id: str,
        query_embedding: List[float]
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant chunks from a specific document.

        Args:
            document_id: Document ID
            query_embedding: Query embedding vector

        Returns:
            List of retrieved chunks with metadata
        """
        document_path = os.path.join(self.processed_dir, f"{document_id}.json")
        if not os.path.exists(document_path):
            return []

        with open(document_path, 'r', encoding='utf-8') as f:
            document_data = json.load(f)

        chunks = document_data.get("chunks", [])
        embeddings = document_data.get("embeddings", [])
        metadata = document_data.get("metadata", {})

        # Bail out on malformed files: every chunk needs a matching embedding.
        if not chunks or not embeddings or len(chunks) != len(embeddings):
            return []

        # Score every chunk in the document against the query embedding.
        results = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            similarity = self._cosine_similarity(query_embedding, embedding)
            results.append({
                "document_id": document_id,
                "chunk_index": i,
                "chunk_text": chunk,
                "similarity": similarity,
                "metadata": metadata
            })

        # Return the document's best-matching chunks first.
        results.sort(key=lambda x: x["similarity"], reverse=True)
        return results[:self.max_chunks]
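
    # Assumed on-disk layout of a processed document file, inferred from the
    # keys read above; the field contents are illustrative:
    #
    #     {
    #         "chunks": ["first chunk text", "second chunk text", ...],
    #         "embeddings": [[0.12, -0.34, ...], [0.56, 0.78, ...]],
    #         "metadata": {"title": "...", "source": "..."}
    #     }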

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """
        Calculate cosine similarity between two vectors.

        Args:
            vec1: First vector
            vec2: Second vector

        Returns:
            Cosine similarity score
        """
        vec1 = np.array(vec1)
        vec2 = np.array(vec2)

        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        # Guard against zero vectors, which would otherwise divide by zero.
        if norm1 == 0 or norm2 == 0:
            return 0.0

        # Cast to a plain float so the annotated return type holds
        # (np.dot on 1-D arrays yields a numpy scalar).
        return float(dot_product / (norm1 * norm2))
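
    # Worked example: for vec1 = [1, 0] and vec2 = [1, 1], the dot product is
    # 1 and the norms are 1 and sqrt(2), giving 1 / sqrt(2) ≈ 0.7071.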

    def _load_document_index(self) -> Dict[str, Dict[str, Any]]:
        """
        Load the document index from disk.

        Returns:
            Dictionary of document IDs to metadata
        """
        if os.path.exists(self.document_index_path):
            try:
                with open(self.document_index_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                print(f"Error loading document index: {str(e)}")

        # Fall back to an empty index if the file is missing or unreadable.
        return {}
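

if __name__ == "__main__":
    # Minimal usage sketch, not part of the retriever itself. It assumes the
    # processed data (one <document_id>.json per document plus
    # document_index.json) already exists under processed_dir, and that
    # HuggingFaceAPI is configured with valid credentials. Because this module
    # uses relative imports, run it as a module, e.g.
    # `python -m <your_package>.retriever` (the package path is illustrative).
    retriever = Retriever()
    for result in retriever.retrieve("Hva er hovedstaden i Norge?"):
        print(
            f"{result['document_id']}[{result['chunk_index']}] "
            f"similarity={result['similarity']:.3f}: "
            f"{result['chunk_text'][:80]}"
        )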