# discord-bot/ui/ui.py
"""This file should be imported only and only if you want to run the UI locally."""
import itertools
import logging
from pathlib import Path
from typing import Any
import gradio as gr
from fastapi import FastAPI
from gradio.themes.utils.colors import slate
from llama_index.llms import ChatMessage, MessageRole
from app._config import settings
from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent
from app.enums import PROJECT_ROOT_PATH
from app.server.chat.service import ChatService
from app.server.ingest.service import IngestService
from app.ui.schemas import Source
logger = logging.getLogger(__name__)

THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "dodge_ava.jpg"

UI_TAB_TITLE = "Agriculture Chatbot"
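# Marker appended between a chat answer and its retrieved sources;
# _chat() also splits on this marker to strip sources from stored history.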
SOURCES_SEPARATOR = "\n\n Sources: \n"


class PrivateGptUi:
def __init__(
self,
ingest_service: IngestService,
chat_service: ChatService,
) -> None:
self._ingest_service = ingest_service
self._chat_service = chat_service
# Cache the UI blocks
self._ui_block = None
# Initialize system prompt
self._system_prompt = self._get_default_system_prompt()

    def _chat(self, message: str, history: list[list[str]], *_: Any) -> Any:
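        # Gradio passes history as [[user_message, bot_message], ...] pairs;
        # rebuild it as an alternating list of ChatMessage objects.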
def build_history() -> list[ChatMessage]:
            history_messages: list[ChatMessage] = list(
                itertools.chain.from_iterable(
                    [
                        ChatMessage(content=interaction[0], role=MessageRole.USER),
                        ChatMessage(
                            # Remove the appended sources block from the
                            # stored assistant answer
                            content=interaction[1].split(SOURCES_SEPARATOR)[0],
                            role=MessageRole.ASSISTANT,
                        ),
                    ]
                    for interaction in history
                )
            )
            # Keep only the 20 most recent messages to try to avoid
            # context overflow
            return history_messages[-20:]

        new_message = ChatMessage(content=message, role=MessageRole.USER)
all_messages = [*build_history(), new_message]
# If a system prompt is set, add it as a system message
if self._system_prompt:
all_messages.insert(
0,
ChatMessage(
content=self._system_prompt,
role=MessageRole.SYSTEM,
),
)
completion = self._chat_service.chat(messages=all_messages)
full_response = completion.response
if completion.sources:
full_response += SOURCES_SEPARATOR
curated_sources = Source.curate_sources(completion.sources)
sources_text = "\n\n\n".join(
f"{index}. {source.file} (page {source.page})"
for index, source in enumerate(curated_sources, start=1)
)
full_response += sources_text
return full_response

    # Used on initialization to set the system prompt
    # to the default prompt defined in settings.
    @staticmethod
    def _get_default_system_prompt() -> str:
        return settings.DEFAULT_QUERY_SYSTEM_PROMPT

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        logger.info("Setting system prompt to: %s", system_prompt_input)
self._system_prompt = system_prompt_input

    def _list_ingested_files(self) -> list[list[str]]:
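        # Collect unique file names; gr.List expects a list of rows.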
files = set()
for ingested_document in self._ingest_service.list_ingested():
if ingested_document.doc_metadata is None:
# Skipping documents without metadata
continue
file_name = ingested_document.doc_metadata.get(
"file_name", "[FILE NAME MISSING]"
)
files.add(file_name)
return [[row] for row in files]

    def _upload_file(self, files: list[str]) -> None:
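        # Gradio hands over uploaded files as filesystem paths; ingest each
        # one under its base file name.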
logger.debug("Loading count=%s files", len(files))
paths = [Path(file) for file in files]
self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])

    def _build_ui_blocks(self) -> gr.Blocks:
logger.debug("Creating the UI blocks")
with gr.Blocks(
title=UI_TAB_TITLE,
theme=gr.themes.Soft(primary_hue=slate),
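            # The #component-* IDs in the CSS below are auto-generated by
            # Gradio, depend on component order, and may change across
            # Gradio versions.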
css=".logo { "
"display:flex;"
"height: 80px;"
"border-radius: 8px;"
"align-content: center;"
"justify-content: center;"
"align-items: center;"
"}"
".logo img { height: 25% }"
".contain { display: flex !important; flex-direction: column !important; }"
"#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
"#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
"#col { height: calc(100vh - 112px - 16px) !important; }",
) as blocks:
with gr.Row():
gr.HTML(f"<div class='logo'/><h1>{UI_TAB_TITLE}</h1></div")
with gr.Row(equal_height=False):
with gr.Column(scale=3):
upload_button = gr.components.UploadButton(
"Upload File(s)",
type="filepath",
file_count="multiple",
size="sm",
)
ingested_dataset = gr.List(
self._list_ingested_files,
headers=["File name"],
label="Ingested Files",
interactive=False,
render=False, # Rendered under the button
)
upload_button.upload(
self._upload_file,
inputs=upload_button,
outputs=ingested_dataset,
)
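                    # Re-list ingested files whenever the dataset changes
                    # (e.g. after an upload)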
ingested_dataset.change(
self._list_ingested_files,
outputs=ingested_dataset,
)
ingested_dataset.render()
system_prompt_input = gr.Textbox(
placeholder=self._system_prompt,
label="System Prompt",
lines=2,
interactive=True,
render=False,
)
# On blur, set system prompt to use in queries
system_prompt_input.blur(
self._set_system_prompt,
inputs=system_prompt_input,
)
with gr.Column(scale=7, elem_id="col"):
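                    # additional_inputs are forwarded to _chat but swallowed by
                    # *_; the system prompt is read from self._system_prompt,
                    # which the blur handler above keeps up to date.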
_ = gr.ChatInterface(
self._chat,
chatbot=gr.Chatbot(
label=f"LLM: {settings.LLM_MODE}",
show_copy_button=True,
elem_id="chatbot",
render=False,
avatar_images=(
None,
AVATAR_BOT,
),
),
additional_inputs=[upload_button, system_prompt_input],
)
return blocks

    def get_ui_blocks(self) -> gr.Blocks:
if self._ui_block is None:
self._ui_block = self._build_ui_blocks()
return self._ui_block

    def mount_in_app(self, app: FastAPI, path: str) -> None:
blocks = self.get_ui_blocks()
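        # Enable Gradio's event queue before mounting so concurrent chat
        # requests are queued rather than handled unbounded.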
blocks.queue()
logger.info("Mounting the gradio UI, at path=%s", path)
gr.mount_gradio_app(app, blocks, path=path)


if __name__ == "__main__":
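    # Standalone entry point: wire up the shared components and services
    # manually instead of going through the FastAPI app.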
llm_component = LLMComponent()
vector_store_component = VectorStoreComponent()
embedding_component = EmbeddingComponent()
node_store_component = NodeStoreComponent()
ingest_service = IngestService(
llm_component, vector_store_component, embedding_component, node_store_component
)
chat_service = ChatService(
llm_component, vector_store_component, embedding_component, node_store_component
)
ui = PrivateGptUi(ingest_service, chat_service)
_blocks = ui.get_ui_blocks()
_blocks.queue()
_blocks.launch(debug=False, show_api=False)