# src/raglite/_eval.py
"""Generation and evaluation of evals."""
from random import randint
from typing import ClassVar
import numpy as np
import pandas as pd
from pydantic import BaseModel, Field, field_validator
from sqlmodel import Session, func, select
from tqdm.auto import tqdm, trange
from raglite._config import RAGLiteConfig
from raglite._database import Chunk, Document, Eval, create_database_engine
from raglite._extract import extract_with_llm
from raglite._rag import rag
from raglite._search import hybrid_search, retrieve_segments, vector_search
from raglite._typing import SearchMethod


def insert_evals(  # noqa: C901
*, num_evals: int = 100, max_contexts_per_eval: int = 20, config: RAGLiteConfig | None = None
) -> None:
"""Generate and insert evals into the database."""
class QuestionResponse(BaseModel):
"""A specific question about the content of a set of document contexts."""
question: str = Field(
...,
description="A specific question about the content of a set of document contexts.",
min_length=1,
)
system_prompt: ClassVar[str] = """
You are given a set of contexts extracted from a document.
You are a subject matter expert on the document's topic.
Your task is to generate a question to quiz other subject matter experts on the information in the provided context.
The question MUST satisfy ALL of the following criteria:
- The question SHOULD integrate as much of the provided context as possible.
- The question MUST NOT be a general or open question, but MUST instead be as specific to the provided context as possible.
- The question MUST be completely answerable using ONLY the information in the provided context, without depending on any background information.
- The question MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
- The question MUST NOT reference the existence of the context, directly or indirectly.
- The question MUST treat the context as if its contents are entirely part of your working memory.
""".strip()
@field_validator("question")
@classmethod
def validate_question(cls, value: str) -> str:
"""Validate the question."""
question = value.strip().lower()
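# Heuristic validation: questions that mention "context", "document", or "question" tend
# to leak the task framing, and a valid question should end with a question mark. Raising
# ValueError rejects the candidate; the generation loop below treats a ValueError from
# extract_with_llm() as a signal to skip this seed chunk.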
if "context" in question or "document" in question or "question" in question:
raise ValueError
if not question.endswith("?"):
raise ValueError
return value
config = config or RAGLiteConfig()
engine = create_database_engine(config)
with Session(engine) as session:
for _ in trange(num_evals, desc="Generating evals", unit="eval", dynamic_ncols=True):
# Sample a random document from the database.
seed_document = session.exec(select(Document).order_by(func.random()).limit(1)).first()
if seed_document is None:
error_message = "First run `insert_document()` before generating evals."
raise ValueError(error_message)
# Sample a random chunk from that document.
seed_chunk = session.exec(
select(Chunk)
.where(Chunk.document_id == seed_document.id)
.order_by(func.random())
.limit(1)
).first()
if seed_chunk is None:
continue
# Expand the seed chunk into a set of related chunks.
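# The seed chunk's (possibly multi-vector) embedding matrix is averaged into a single
# query embedding, and between 2 and max_contexts_per_eval // 2 neighbouring chunks are
# retrieved so that the generated question can draw on more than one context.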
related_chunk_ids, _ = vector_search(
np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
num_results=randint(2, max_contexts_per_eval // 2),  # noqa: S311
config=config,
)
related_chunks = retrieve_segments(related_chunk_ids, config=config)
# Extract a question from the seed chunk's related chunks.
try:
question_response = extract_with_llm(
QuestionResponse, related_chunks, config=config
)
except ValueError:
continue
else:
question = question_response.question
# Search for candidate chunks to answer the generated question.
candidate_chunk_ids, _ = hybrid_search(
question, num_results=max_contexts_per_eval, config=config
)
candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]
# Determine which candidate chunks are relevant to answer the generated question.
class ContextEvalResponse(BaseModel):
"""Indicate whether the provided context can be used to answer a given question."""
hit: bool = Field(
...,
description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",
)
system_prompt: ClassVar[str] = f"""
You are given a context extracted from a document.
You are a subject matter expert on the document's topic.
Your task is to answer whether the provided context contains (a part of) the answer to this question: "{question}"
An example of a context that does NOT contain (a part of) the answer is a table of contents.
""".strip()
relevant_chunks = []
for candidate_chunk in tqdm(
candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True
):
try:
context_eval_response = extract_with_llm(
ContextEvalResponse, str(candidate_chunk), config=config
)
except ValueError:  # noqa: PERF203
pass
else:
if context_eval_response.hit:
relevant_chunks.append(candidate_chunk)
if not relevant_chunks:
continue
# Answer the question using the relevant chunks.
class AnswerResponse(BaseModel):
"""Answer a question using the provided context."""
answer: str = Field(
...,
description="A complete answer to the given question using the provided context.",
min_length=1,
)
system_prompt: ClassVar[str] = f"""
You are given a set of contexts extracted from a document.
You are a subject matter expert on the document's topic.
Your task is to generate a complete answer to the following question using the provided context: "{question}"
The answer MUST satisfy ALL of the following criteria:
- The answer MUST integrate as much of the provided context as possible.
- The answer MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
- The answer MUST NOT reference the existence of the context, directly or indirectly.
- The answer MUST treat the context as if its contents are entirely part of your working memory.
""".strip()
try:
answer_response = extract_with_llm(
AnswerResponse,
[str(relevant_chunk) for relevant_chunk in relevant_chunks],
config=config,
)
except ValueError:
continue
else:
answer = answer_response.answer
# Store the eval in the database.
eval_ = Eval.from_chunks(
question=question,
contexts=relevant_chunks,
ground_truth=answer,
)
session.add(eval_)
session.commit()
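
# Example usage (an illustrative sketch; assumes `insert_document()` has already been
# used to populate the database that `config` points to):
#
#   config = RAGLiteConfig()
#   insert_evals(num_evals=25, max_contexts_per_eval=10, config=config)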


def answer_evals(
num_evals: int = 100,
search: SearchMethod = hybrid_search,
*,
config: RAGLiteConfig | None = None,
) -> pd.DataFrame:
"""Read evals from the database and answer them with RAG."""
# Read evals from the database.
config = config or RAGLiteConfig()
engine = create_database_engine(config)
with Session(engine) as session:
evals = session.exec(select(Eval).limit(num_evals)).all()
# Answer evals with RAG.
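# rag() yields its response incrementally, so the pieces are joined into a single answer
# string; the contexts that the search method retrieves for each question are recorded
# alongside the answer for later scoring.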
answers: list[str] = []
contexts: list[list[str]] = []
for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
response = rag(eval_.question, search=search, config=config)
answer = "".join(response)
answers.append(answer)
chunk_ids, _ = search(eval_.question, config=config)
contexts.append(retrieve_segments(chunk_ids, config=config))
# Collect the answered evals.
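# The resulting columns (question, answer, contexts, ground_truth, ground_truth_contexts)
# form the dataset that `evaluate()` below hands to Ragas.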
answered_evals: dict[str, list[str] | list[list[str]]] = {
"question": [eval_.question for eval_ in evals],
"answer": answers,
"contexts": contexts,
"ground_truth": [eval_.ground_truth for eval_ in evals],
"ground_truth_contexts": [eval_.contexts for eval_ in evals],
}
answered_evals_df = pd.DataFrame.from_dict(answered_evals)
return answered_evals_df
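
# Example usage (illustrative; requires that `insert_evals()` has populated the database):
#
#   answered_evals_df = answer_evals(num_evals=10, search=hybrid_search, config=RAGLiteConfig())
#   # Columns: question, answer, contexts, ground_truth, ground_truth_contexts.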


def evaluate(
answered_evals: pd.DataFrame | int = 100,
config: RAGLiteConfig | None = None,
) -> pd.DataFrame:
"""Evaluate the performance of a set of answered evals with Ragas."""
try:
from datasets import Dataset
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import LlamaCppEmbeddings
from langchain_community.llms import LlamaCpp
from ragas import RunConfig
from ragas import evaluate as ragas_evaluate
from raglite._litellm import LlamaCppPythonLLM
except ImportError as import_error:
error_message = "To use the `evaluate` function, please install the `ragas` extra."
raise ImportError(error_message) from import_error
# Create a set of answered evals if not provided.
config = config or RAGLiteConfig()
answered_evals_df = (
answered_evals
if isinstance(answered_evals, pd.DataFrame)
else answer_evals(num_evals=answered_evals, config=config)
)
# Load the LLM.
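# For llama-cpp-python models, the already-loaded llama.cpp model's settings (model path,
# batch size, context size) are mirrored into a LangChain `LlamaCpp` wrapper so that Ragas
# can drive it; n_gpu_layers=-1 offloads all layers to the GPU when one is available.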
if config.llm.startswith("llama-cpp-python"):
llm = LlamaCppPythonLLM().llm(model=config.llm)
lc_llm = LlamaCpp(
model_path=llm.model_path,
n_batch=llm.n_batch,
n_ctx=llm.n_ctx(),
n_gpu_layers=-1,
verbose=llm.verbose,
)
else:
lc_llm = ChatLiteLLM(model=config.llm) # type: ignore[call-arg]
# Load the embedder.
if not config.embedder.startswith("llama-cpp-python"):
error_message = "Currently, only `llama-cpp-python` embedders are supported."
raise NotImplementedError(error_message)
embedder = LlamaCppPythonLLM().llm(model=config.embedder, embedding=True)
lc_embedder = LlamaCppEmbeddings( # type: ignore[call-arg]
model_path=embedder.model_path,
n_batch=embedder.n_batch,
n_ctx=embedder.n_ctx(),
n_gpu_layers=-1,
verbose=embedder.verbose,
)
# Evaluate the answered evals with Ragas.
evaluation_df = ragas_evaluate(
dataset=Dataset.from_pandas(answered_evals_df),
llm=lc_llm,
embeddings=lc_embedder,
run_config=RunConfig(max_workers=1),
).to_pandas()
return evaluation_df
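

if __name__ == "__main__":
    # Minimal end-to-end sketch: generate evals, answer them with RAG, and score the
    # answers with Ragas. This assumes documents have already been inserted into the
    # database that the default RAGLiteConfig() points to; adjust the config as needed.
    demo_config = RAGLiteConfig()
    insert_evals(num_evals=10, config=demo_config)
    answered = answer_evals(num_evals=10, config=demo_config)
    print(evaluate(answered, config=demo_config))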