| """Chainlit frontend for RAGLite.""" | |
| import os | |
| from pathlib import Path | |
| import chainlit as cl | |
| from chainlit.input_widget import Switch, TextInput | |
| from raglite import ( | |
| RAGLiteConfig, | |
| async_rag, | |
| hybrid_search, | |
| insert_document, | |
| rerank_chunks, | |
| retrieve_chunks, | |
| ) | |
| from raglite._markdown import document_to_markdown | |
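
# RAGLite's API is synchronous, so wrap the functions used below with cl.make_async to run
# them in a worker thread instead of blocking Chainlit's asyncio event loop.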
async_insert_document = cl.make_async(insert_document)
async_hybrid_search = cl.make_async(hybrid_search)
async_retrieve_chunks = cl.make_async(retrieve_chunks)
async_rerank_chunks = cl.make_async(rerank_chunks)


@cl.on_chat_start
async def start_chat() -> None:
| """Initialize the chat.""" | |
| # Disable tokenizes parallelism to avoid the deadlock warning. | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # Add Chainlit settings with which the user can configure the RAGLite config. | |
| default_config = RAGLiteConfig() | |
| config = RAGLiteConfig( | |
| db_url=os.environ.get("RAGLITE_DB_URL", default_config.db_url), | |
| llm=os.environ.get("RAGLITE_LLM", default_config.llm), | |
| embedder=os.environ.get("RAGLITE_EMBEDDER", default_config.embedder), | |
| ) | |
| settings = await cl.ChatSettings( # type: ignore[no-untyped-call] | |
| [ | |
| TextInput(id="db_url", label="Database URL", initial=str(config.db_url)), | |
| TextInput(id="llm", label="LLM", initial=config.llm), | |
| TextInput(id="embedder", label="Embedder", initial=config.embedder), | |
| Switch(id="vector_search_query_adapter", label="Query adapter", initial=True), | |
| ] | |
| ).send() | |
| await update_config(settings) | |
| # type: ignore[arg-type] | |
| async def update_config(settings: cl.ChatSettings) -> None: | |
| """Update the RAGLite config.""" | |
| # Update the RAGLite config given the Chainlit settings. | |
| config = RAGLiteConfig( | |
| db_url=settings["db_url"], # type: ignore[index] | |
| llm=settings["llm"], # type: ignore[index] | |
| embedder=settings["embedder"], # type: ignore[index] | |
| vector_search_query_adapter=settings["vector_search_query_adapter"], # type: ignore[index] | |
| ) | |
| cl.user_session.set("config", config) # type: ignore[no-untyped-call] | |
| # Run a search to prime the pipeline if it's a local pipeline. | |
| # TODO: Don't do this for SQLite once we switch from PyNNDescent to sqlite-vec. | |
| if str(config.db_url).startswith("sqlite") or config.embedder.startswith("llama-cpp-python"): | |
| # async with cl.Step(name="initialize", type="retrieval"): | |
| query = "Hello world" | |
| chunk_ids, _ = await async_hybrid_search(query=query, config=config) | |
| _ = await async_rerank_chunks(query=query, chunk_ids=chunk_ids, config=config) | |
| async def handle_message(user_message: cl.Message) -> None: | |
| """Respond to a user message.""" | |
| # Get the config and message history from the user session. | |
| config: RAGLiteConfig = cl.user_session.get("config") # type: ignore[no-untyped-call] | |
| # Determine what to do with the attachments. | |
| inline_attachments = [] | |
| for file in user_message.elements: | |
| if file.path: | |
| doc_md = document_to_markdown(Path(file.path)) | |
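            # Compare an approximate token count (~3 characters per token) against roughly
            # five chunks' worth of tokens to decide whether to inline the document.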
            if len(doc_md) // 3 <= 5 * (config.chunk_max_size // 3):
                # Document is small enough to attach to the context.
                inline_attachments.append(f"{Path(file.path).name}:\n\n{doc_md}")
            else:
                # Document is too large and must be inserted into the database.
                async with cl.Step(name="insert", type="run") as step:
                    step.input = Path(file.path).name
                    await async_insert_document(Path(file.path), config=config)
    # Append any inline attachments to the user prompt.
    user_prompt = f"{user_message.content}\n\n" + "\n\n".join(
        f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
        for i, attachment in enumerate(inline_attachments)
    )
    # Search for relevant contexts for RAG.
    async with cl.Step(name="search", type="retrieval") as step:
        step.input = user_message.content
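        # Hybrid search combines keyword and vector search; retrieve_chunks then loads the
        # matching chunk records from the database.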
        chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
        chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
        step.output = chunks
        step.elements = [  # Show the top 3 chunks inline.
            cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
        ]
    # Rerank the chunks.
    async with cl.Step(name="rerank", type="rerank") as step:
        step.input = chunks
        chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
        step.output = chunks
        step.elements = [  # Show the top 3 chunks inline.
            cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
        ]
    # Stream the LLM response.
    assistant_message = cl.Message(content="")
    async for token in async_rag(
        prompt=user_prompt,
        search=chunks,
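        # Pass the last few messages of the conversation (in OpenAI format) as chat history.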
        messages=cl.chat_context.to_openai()[-5:],  # type: ignore[no-untyped-call]
        config=config,
    ):
        await assistant_message.stream_token(token)
    await assistant_message.update()  # type: ignore[no-untyped-call]
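
# A minimal way to serve this frontend locally (assuming this module is saved as app.py):
#   chainlit run app.py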