"""This file should be imported only and only if you want to run the UI locally.""" | |
import itertools | |
import logging | |
from pathlib import Path | |
from typing import Any | |
import gradio as gr | |
from fastapi import FastAPI | |
from gradio.themes.utils.colors import slate | |
from llama_index.llms import ChatMessage, MessageRole | |
from app._config import settings | |
from app.components.embedding.component import EmbeddingComponent | |
from app.components.llm.component import LLMComponent | |
from app.components.node_store.component import NodeStoreComponent | |
from app.components.vector_store.component import VectorStoreComponent | |
from app.enums import PROJECT_ROOT_PATH | |
from app.server.chat.service import ChatService | |
from app.server.ingest.service import IngestService | |
from app.ui.schemas import Source | |
logger = logging.getLogger(__name__) | |
THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH) | |
AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "dodge_ava.jpg" | |
UI_TAB_TITLE = "Agriculture Chatbot" | |
SOURCES_SEPARATOR = "\n\n Sources: \n" | |
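# SOURCES_SEPARATOR does double duty: _chat() appends it before the sources
# list in each reply, and splits on it to strip sources back out of history.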


class PrivateGptUi:
    def __init__(
        self,
        ingest_service: IngestService,
        chat_service: ChatService,
    ) -> None:
        self._ingest_service = ingest_service
        self._chat_service = chat_service

        # Cache the UI blocks
        self._ui_block = None

        # Initialize system prompt
        self._system_prompt = self._get_default_system_prompt()
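
    # Gradio's ChatInterface calls _chat with the new message, the accumulated
    # [user, assistant] pairs, and any additional_inputs as extra positional
    # arguments (swallowed here via *_).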
    def _chat(self, message: str, history: list[list[str]], *_: Any) -> Any:
        def build_history() -> list[ChatMessage]:
            history_messages: list[ChatMessage] = list(
                itertools.chain(
                    *[
                        [
                            ChatMessage(content=interaction[0], role=MessageRole.USER),
                            ChatMessage(
                                # Remove the Sources information from the history content
                                content=interaction[1].split(SOURCES_SEPARATOR)[0],
                                role=MessageRole.ASSISTANT,
                            ),
                        ]
                        for interaction in history
                    ]
                )
            )

            # Keep only the most recent 20 messages to try to avoid context overflow
            return history_messages[-20:]

        new_message = ChatMessage(content=message, role=MessageRole.USER)
        all_messages = [*build_history(), new_message]
        # If a system prompt is set, add it as a system message
        if self._system_prompt:
            all_messages.insert(
                0,
                ChatMessage(
                    content=self._system_prompt,
                    role=MessageRole.SYSTEM,
                ),
            )
        completion = self._chat_service.chat(messages=all_messages)
        full_response = completion.response
        if completion.sources:
            full_response += SOURCES_SEPARATOR
            curated_sources = Source.curate_sources(completion.sources)
            sources_text = "\n\n\n".join(
                f"{index}. {source.file} (page {source.page})"
                for index, source in enumerate(curated_sources, start=1)
            )
            full_response += sources_text
        return full_response

    # Returns the default system prompt from settings; used once at
    # initialization to seed the system prompt.
    @staticmethod
    def _get_default_system_prompt() -> str:
        return settings.DEFAULT_QUERY_SYSTEM_PROMPT

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        logger.info(f"Setting system prompt to: {system_prompt_input}")
        self._system_prompt = system_prompt_input
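
    # gr.List consumes row-shaped data, so each file name is wrapped
    # as a one-column row.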
    def _list_ingested_files(self) -> list[list[str]]:
        files = set()
        for ingested_document in self._ingest_service.list_ingested():
            if ingested_document.doc_metadata is None:
                # Skipping documents without metadata
                continue
            file_name = ingested_document.doc_metadata.get(
                "file_name", "[FILE NAME MISSING]"
            )
            files.add(file_name)
        return [[row] for row in files]

    def _upload_file(self, files: list[str]) -> None:
        logger.debug("Loading count=%s files", len(files))
        paths = [Path(file) for file in files]
        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])
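
    # NOTE: the #component-N selectors in the CSS below target Gradio's
    # auto-generated element ids, so they are tied to this exact component
    # layout and Gradio version.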
    def _build_ui_blocks(self) -> gr.Blocks:
        logger.debug("Creating the UI blocks")
        with gr.Blocks(
            title=UI_TAB_TITLE,
            theme=gr.themes.Soft(primary_hue=slate),
            css=".logo { "
            "display:flex;"
            "height: 80px;"
            "border-radius: 8px;"
            "align-content: center;"
            "justify-content: center;"
            "align-items: center;"
            "}"
            ".logo img { height: 25% }"
            ".contain { display: flex !important; flex-direction: column !important; }"
            "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
            "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
            "#col { height: calc(100vh - 112px - 16px) !important; }",
        ) as blocks:
            with gr.Row():
                gr.HTML(f"<div class='logo'><h1>{UI_TAB_TITLE}</h1></div>")

            with gr.Row(equal_height=False):
                with gr.Column(scale=3):
                    upload_button = gr.components.UploadButton(
                        "Upload File(s)",
                        type="filepath",
                        file_count="multiple",
                        size="sm",
                    )
                    ingested_dataset = gr.List(
                        self._list_ingested_files,
                        headers=["File name"],
                        label="Ingested Files",
                        interactive=False,
                        render=False,  # Rendered under the button
                    )
                    upload_button.upload(
                        self._upload_file,
                        inputs=upload_button,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.change(
                        self._list_ingested_files,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.render()
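
                    # Declared with render=False so ChatInterface can place it
                    # with the other additional inputs beneath the chat box.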
                    system_prompt_input = gr.Textbox(
                        placeholder=self._system_prompt,
                        label="System Prompt",
                        lines=2,
                        interactive=True,
                        render=False,
                    )
                    # On blur, set the system prompt to use in queries
                    system_prompt_input.blur(
                        self._set_system_prompt,
                        inputs=system_prompt_input,
                    )
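
                # elem_id="col" lets the CSS above pin this column to the
                # viewport height so the chatbot pane scrolls internally.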
                with gr.Column(scale=7, elem_id="col"):
                    _ = gr.ChatInterface(
                        self._chat,
                        chatbot=gr.Chatbot(
                            label=f"LLM: {settings.LLM_MODE}",
                            show_copy_button=True,
                            elem_id="chatbot",
                            render=False,
                            avatar_images=(
                                None,
                                AVATAR_BOT,
                            ),
                        ),
                        additional_inputs=[upload_button, system_prompt_input],
                    )
        return blocks

    def get_ui_blocks(self) -> gr.Blocks:
        if self._ui_block is None:
            self._ui_block = self._build_ui_blocks()
        return self._ui_block

    def mount_in_app(self, app: FastAPI, path: str) -> None:
        blocks = self.get_ui_blocks()
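        # Enable Gradio's event queue before mounting so chat events are
        # processed through it once the UI is served by FastAPI.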
        blocks.queue()
        logger.info("Mounting the gradio UI, at path=%s", path)
        gr.mount_gradio_app(app, blocks, path=path)
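

# To serve the UI from an existing FastAPI app instead of launching it
# standalone, the pattern is roughly the following (illustrative wiring;
# build the services the same way the __main__ block below does):
#
#   app = FastAPI()
#   ui = PrivateGptUi(ingest_service, chat_service)
#   ui.mount_in_app(app, path="/ui")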
if __name__ == "__main__":
    llm_component = LLMComponent()
    vector_store_component = VectorStoreComponent()
    embedding_component = EmbeddingComponent()
    node_store_component = NodeStoreComponent()
    ingest_service = IngestService(
        llm_component, vector_store_component, embedding_component, node_store_component
    )
    chat_service = ChatService(
        llm_component, vector_store_component, embedding_component, node_store_component
    )

    ui = PrivateGptUi(ingest_service, chat_service)
    _blocks = ui.get_ui_blocks()
    _blocks.queue()
    _blocks.launch(debug=False, show_api=False)