"""Retrieval-augmented generation.""" from collections.abc import AsyncIterator, Iterator from litellm import acompletion, completion, get_model_info # type: ignore[attr-defined] from raglite._config import RAGLiteConfig from raglite._database import Chunk from raglite._litellm import LlamaCppPythonLLM from raglite._search import hybrid_search, rerank_chunks, retrieve_segments from raglite._typing import SearchMethod RAG_SYSTEM_PROMPT = """ You are a friendly and knowledgeable assistant that provides complete and insightful answers. Answer the user's question using only the context below. When responding, you MUST NOT reference the existence of the context, directly or indirectly. Instead, you MUST treat the context as if its contents are entirely part of your working memory. """.strip() def _max_contexts( prompt: str, *, max_contexts: int = 5, context_neighbors: tuple[int, ...] | None = (-1, 1), messages: list[dict[str, str]] | None = None, config: RAGLiteConfig | None = None, ) -> int: """Determine the maximum number of contexts for RAG.""" # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up # to date by loading that LLM. config = config or RAGLiteConfig() if config.llm.startswith("llama-cpp-python"): _ = LlamaCppPythonLLM.llm(config.llm) # Get the model's maximum context size. llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None model_info = get_model_info(config.llm, custom_llm_provider=llm_provider) max_tokens = model_info.get("max_tokens") or 2048 # Reduce the maximum number of contexts to take into account the LLM's context size. max_context_tokens = ( max_tokens - sum(len(message["content"]) // 3 for message in messages or []) # Previous messages. - len(RAG_SYSTEM_PROMPT) // 3 # System prompt. - len(prompt) // 3 # User prompt. ) max_tokens_per_context = config.chunk_max_size // 3 max_tokens_per_context *= 1 + len(context_neighbors or []) max_contexts = min(max_contexts, max_context_tokens // max_tokens_per_context) if max_contexts <= 0: error_message = "Not enough context tokens available for RAG." raise ValueError(error_message) return max_contexts def _contexts( # noqa: PLR0913 prompt: str, *, max_contexts: int = 5, context_neighbors: tuple[int, ...] | None = (-1, 1), search: SearchMethod | list[str] | list[Chunk] = hybrid_search, messages: list[dict[str, str]] | None = None, config: RAGLiteConfig | None = None, ) -> list[str]: """Retrieve contexts for RAG.""" # Determine the maximum number of contexts. max_contexts = _max_contexts( prompt, max_contexts=max_contexts, context_neighbors=context_neighbors, messages=messages, config=config, ) # Retrieve the top chunks. config = config or RAGLiteConfig() chunks: list[str] | list[Chunk] if callable(search): # If the user has configured a reranker, we retrieve extra contexts to rerank. extra_contexts = 3 * max_contexts if config.reranker else 0 # Retrieve relevant contexts. chunk_ids, _ = search(prompt, num_results=max_contexts + extra_contexts, config=config) # Rerank the relevant contexts. chunks = rerank_chunks(query=prompt, chunk_ids=chunk_ids, config=config) else: # The user has passed a list of chunk_ids or chunks directly. chunks = search # Extend the top contexts with their neighbors and group chunks into contiguous segments. segments = retrieve_segments(chunks[:max_contexts], neighbors=context_neighbors, config=config) return segments def rag( # noqa: PLR0913 prompt: str, *, max_contexts: int = 5, context_neighbors: tuple[int, ...] 


def rag(  # noqa: PLR0913
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
    messages: list[dict[str, str]] | None = None,
    system_prompt: str = RAG_SYSTEM_PROMPT,
    config: RAGLiteConfig | None = None,
) -> Iterator[str]:
    """Retrieval-augmented generation."""
    # Get the contexts for RAG as contiguous segments of chunks.
    config = config or RAGLiteConfig()
    segments = _contexts(
        prompt,
        max_contexts=max_contexts,
        context_neighbors=context_neighbors,
        search=search,
        messages=messages,
        config=config,
    )
    system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
        f'<context index="{i}">\n{segment.strip()}\n</context>'
        for i, segment in enumerate(segments)
    )
    # Stream the LLM response.
    stream = completion(
        model=config.llm,
        messages=[
            *(messages or []),
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    for output in stream:
        token: str = output["choices"][0]["delta"].get("content") or ""
        yield token


async def async_rag(  # noqa: PLR0913
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
    messages: list[dict[str, str]] | None = None,
    system_prompt: str = RAG_SYSTEM_PROMPT,
    config: RAGLiteConfig | None = None,
) -> AsyncIterator[str]:
    """Retrieval-augmented generation."""
    # Get the contexts for RAG as contiguous segments of chunks.
    config = config or RAGLiteConfig()
    segments = _contexts(
        prompt,
        max_contexts=max_contexts,
        context_neighbors=context_neighbors,
        search=search,
        messages=messages,
        config=config,
    )
    system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
        f'<context index="{i}">\n{segment.strip()}\n</context>'
        for i, segment in enumerate(segments)
    )
    # Stream the LLM response.
    async_stream = await acompletion(
        model=config.llm,
        messages=[
            *(messages or []),
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    async for output in async_stream:
        token: str = output["choices"][0]["delta"].get("content") or ""
        yield token
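

# Example usage (a minimal sketch; it assumes an already-indexed RAGLite
# database and an LLM that LiteLLM can reach -- the model name below is a
# placeholder):
#
#     from raglite._config import RAGLiteConfig
#
#     my_config = RAGLiteConfig(llm="gpt-4o-mini")
#     for token in rag("How is intelligence measured?", config=my_config):
#         print(token, end="")
#
# async_rag is an async generator, so it is consumed from within a coroutine as
# `async for token in async_rag(prompt, config=my_config)`; the call itself is
# not awaited.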