"""Generation and evaluation of evals.""" | |
from random import randint | |
from typing import ClassVar | |
import numpy as np | |
import pandas as pd | |
from pydantic import BaseModel, Field, field_validator | |
from sqlmodel import Session, func, select | |
from tqdm.auto import tqdm, trange | |
from raglite._config import RAGLiteConfig | |
from raglite._database import Chunk, Document, Eval, create_database_engine | |
from raglite._extract import extract_with_llm | |
from raglite._rag import rag | |
from raglite._search import hybrid_search, retrieve_segments, vector_search | |
from raglite._typing import SearchMethod | |


def insert_evals(  # noqa: C901
    *, num_evals: int = 100, max_contexts_per_eval: int = 20, config: RAGLiteConfig | None = None
) -> None:
    """Generate and insert evals into the database."""

    class QuestionResponse(BaseModel):
        """A specific question about the content of a set of document contexts."""

        question: str = Field(
            ...,
            description="A specific question about the content of a set of document contexts.",
            min_length=1,
        )
        system_prompt: ClassVar[str] = """
            You are given a set of contexts extracted from a document.
            You are a subject matter expert on the document's topic.
            Your task is to generate a question to quiz other subject matter experts on the information in the provided context.
            The question MUST satisfy ALL of the following criteria:
            - The question SHOULD integrate as much of the provided context as possible.
            - The question MUST NOT be a general or open question, but MUST instead be as specific to the provided context as possible.
            - The question MUST be completely answerable using ONLY the information in the provided context, without depending on any background information.
            - The question MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
            - The question MUST NOT reference the existence of the context, directly or indirectly.
            - The question MUST treat the context as if its contents are entirely part of your working memory.
            """.strip()
        @field_validator("question")
        @classmethod
        def validate_question(cls, value: str) -> str:
            """Validate the question."""
            # Reject questions that reference the context or document, or that don't end with "?".
            question = value.strip().lower()
            if "context" in question or "document" in question or "question" in question:
                raise ValueError
            if not question.endswith("?"):
                raise ValueError
            return value
    config = config or RAGLiteConfig()
    engine = create_database_engine(config)
    with Session(engine) as session:
        for _ in trange(num_evals, desc="Generating evals", unit="eval", dynamic_ncols=True):
            # Sample a random document from the database.
            seed_document = session.exec(select(Document).order_by(func.random()).limit(1)).first()
            if seed_document is None:
                error_message = "First run `insert_document()` before generating evals."
                raise ValueError(error_message)
            # Sample a random chunk from that document.
            seed_chunk = session.exec(
                select(Chunk)
                .where(Chunk.document_id == seed_document.id)
                .order_by(func.random())
                .limit(1)
            ).first()
            if seed_chunk is None:
                continue
            # Expand the seed chunk into a set of related chunks.
            related_chunk_ids, _ = vector_search(
                np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
                num_results=randint(2, max_contexts_per_eval // 2),  # noqa: S311
                config=config,
            )
            related_chunks = retrieve_segments(related_chunk_ids, config=config)
            # Extract a question from the seed chunk's related chunks.
            try:
                question_response = extract_with_llm(
                    QuestionResponse, related_chunks, config=config
                )
            except ValueError:
                continue
            else:
                question = question_response.question
            # Search for candidate chunks to answer the generated question.
            candidate_chunk_ids, _ = hybrid_search(
                question, num_results=max_contexts_per_eval, config=config
            )
            candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]
            # Determine which candidate chunks are relevant to answer the generated question.
            class ContextEvalResponse(BaseModel):
                """Indicate whether the provided context can be used to answer a given question."""

                hit: bool = Field(
                    ...,
                    description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",
                )
                system_prompt: ClassVar[str] = f"""
                    You are given a context extracted from a document.
                    You are a subject matter expert on the document's topic.
                    Your task is to answer whether the provided context contains (a part of) the answer to this question: "{question}"
                    An example of a context that does NOT contain (a part of) the answer is a table of contents.
                    """.strip()

            relevant_chunks = []
            for candidate_chunk in tqdm(
                candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True
            ):
                try:
                    context_eval_response = extract_with_llm(
                        ContextEvalResponse, str(candidate_chunk), config=config
                    )
                except ValueError:  # noqa: PERF203
                    pass
                else:
                    if context_eval_response.hit:
                        relevant_chunks.append(candidate_chunk)
            if not relevant_chunks:
                continue
            # Answer the question using the relevant chunks.
            class AnswerResponse(BaseModel):
                """Answer a question using the provided context."""

                answer: str = Field(
                    ...,
                    description="A complete answer to the given question using the provided context.",
                    min_length=1,
                )
                system_prompt: ClassVar[str] = f"""
                    You are given a set of contexts extracted from a document.
                    You are a subject matter expert on the document's topic.
                    Your task is to generate a complete answer to the following question using the provided context: "{question}"
                    The answer MUST satisfy ALL of the following criteria:
                    - The answer MUST integrate as much of the provided context as possible.
                    - The answer MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
                    - The answer MUST NOT reference the existence of the context, directly or indirectly.
                    - The answer MUST treat the context as if its contents are entirely part of your working memory.
                    """.strip()

            try:
                answer_response = extract_with_llm(
                    AnswerResponse,
                    [str(relevant_chunk) for relevant_chunk in relevant_chunks],
                    config=config,
                )
            except ValueError:
                continue
            else:
                answer = answer_response.answer
            # Store the eval in the database.
            eval_ = Eval.from_chunks(
                question=question,
                contexts=relevant_chunks,
                ground_truth=answer,
            )
            session.add(eval_)
            session.commit()


def answer_evals(
    num_evals: int = 100,
    search: SearchMethod = hybrid_search,
    *,
    config: RAGLiteConfig | None = None,
) -> pd.DataFrame:
    """Read evals from the database and answer them with RAG."""
    # Read evals from the database.
    config = config or RAGLiteConfig()
    engine = create_database_engine(config)
    with Session(engine) as session:
        evals = session.exec(select(Eval).limit(num_evals)).all()
    # Answer evals with RAG.
    answers: list[str] = []
    contexts: list[list[str]] = []
    for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
        response = rag(eval_.question, search=search, config=config)
        answer = "".join(response)
        answers.append(answer)
        chunk_ids, _ = search(eval_.question, config=config)
        contexts.append(retrieve_segments(chunk_ids, config=config))
    # Collect the answered evals.
    answered_evals: dict[str, list[str] | list[list[str]]] = {
        "question": [eval_.question for eval_ in evals],
        "answer": answers,
        "contexts": contexts,
        "ground_truth": [eval_.ground_truth for eval_ in evals],
        "ground_truth_contexts": [eval_.contexts for eval_ in evals],
    }
    answered_evals_df = pd.DataFrame.from_dict(answered_evals)
    return answered_evals_df


def evaluate(
    answered_evals: pd.DataFrame | int = 100,
    config: RAGLiteConfig | None = None,
) -> pd.DataFrame:
    """Evaluate the performance of a set of answered evals with Ragas."""
    try:
        from datasets import Dataset
        from langchain_community.chat_models import ChatLiteLLM
        from langchain_community.embeddings import LlamaCppEmbeddings
        from langchain_community.llms import LlamaCpp
        from ragas import RunConfig
        from ragas import evaluate as ragas_evaluate

        from raglite._litellm import LlamaCppPythonLLM
    except ImportError as import_error:
        error_message = "To use the `evaluate` function, please install the `ragas` extra."
        raise ImportError(error_message) from import_error
    # Create a set of answered evals if not provided.
    config = config or RAGLiteConfig()
    answered_evals_df = (
        answered_evals
        if isinstance(answered_evals, pd.DataFrame)
        else answer_evals(num_evals=answered_evals, config=config)
    )
    # Load the LLM.
    if config.llm.startswith("llama-cpp-python"):
        llm = LlamaCppPythonLLM().llm(model=config.llm)
        lc_llm = LlamaCpp(
            model_path=llm.model_path,
            n_batch=llm.n_batch,
            n_ctx=llm.n_ctx(),
            n_gpu_layers=-1,
            verbose=llm.verbose,
        )
    else:
        lc_llm = ChatLiteLLM(model=config.llm)  # type: ignore[call-arg]
    # Load the embedder.
    if not config.embedder.startswith("llama-cpp-python"):
        error_message = "Currently, only `llama-cpp-python` embedders are supported."
        raise NotImplementedError(error_message)
    embedder = LlamaCppPythonLLM().llm(model=config.embedder, embedding=True)
    lc_embedder = LlamaCppEmbeddings(  # type: ignore[call-arg]
        model_path=embedder.model_path,
        n_batch=embedder.n_batch,
        n_ctx=embedder.n_ctx(),
        n_gpu_layers=-1,
        verbose=embedder.verbose,
    )
    # Evaluate the answered evals with Ragas.
    evaluation_df = ragas_evaluate(
        dataset=Dataset.from_pandas(answered_evals_df),
        llm=lc_llm,
        embeddings=lc_embedder,
        run_config=RunConfig(max_workers=1),
    ).to_pandas()
    return evaluation_df
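

# Example usage: a minimal sketch, assuming documents have already been ingested with
# `insert_document()` into the database configured by the default `RAGLiteConfig()`,
# and that the `ragas` extra is installed so `evaluate()` can import its dependencies.
if __name__ == "__main__":
    example_config = RAGLiteConfig()
    # Generate and store a small batch of evals.
    insert_evals(num_evals=10, config=example_config)
    # Answer the stored evals with RAG.
    answered_df = answer_evals(num_evals=10, config=example_config)
    # Score the answered evals with Ragas and print the resulting metrics.
    scores_df = evaluate(answered_df, config=example_config)
    print(scores_df)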