"""This file should be imported only and only if you want to run the UI locally."""
import itertools
import logging
from pathlib import Path
from typing import Any

import gradio as gr
from fastapi import FastAPI
from gradio.themes.utils.colors import slate
from llama_index.llms import ChatMessage, MessageRole

from app._config import settings
from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent
from app.enums import PROJECT_ROOT_PATH
from app.server.chat.service import ChatService
from app.server.ingest.service import IngestService
from app.ui.schemas import Source

logger = logging.getLogger(__name__)

THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "dodge_ava.jpg"

UI_TAB_TITLE = "Agriculture Chatbot"
SOURCES_SEPARATOR = "\n\n Sources: \n"


class PrivateGptUi:
    def __init__(
        self,
        ingest_service: IngestService,
        chat_service: ChatService,
    ) -> None:
        self._ingest_service = ingest_service
        self._chat_service = chat_service

        # Cache the UI blocks
        self._ui_block = None

        # Initialize system prompt
        self._system_prompt = self._get_default_system_prompt()
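
    # Gradio ChatInterface callback: assembles the system prompt, the trimmed
    # chat history, and the new user message, then queries the chat service
    # and appends formatted source citations to the response.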
    def _chat(self, message: str, history: list[list[str]], *_: Any) -> Any:
        def build_history() -> list[ChatMessage]:
            history_messages: list[ChatMessage] = list(
                itertools.chain(
                    *[
                        [
                            ChatMessage(content=interaction[0], role=MessageRole.USER),
                            ChatMessage(
                                # Strip the Sources block from the stored history content
                                content=interaction[1].split(SOURCES_SEPARATOR)[0],
                                role=MessageRole.ASSISTANT,
                            ),
                        ]
                        for interaction in history
                    ]
                )
            )
            # Keep only the last 20 messages to try to avoid context overflow
            return history_messages[-20:]

        new_message = ChatMessage(content=message, role=MessageRole.USER)
        all_messages = [*build_history(), new_message]
        # If a system prompt is set, add it as a system message
        if self._system_prompt:
            all_messages.insert(
                0,
                ChatMessage(
                    content=self._system_prompt,
                    role=MessageRole.SYSTEM,
                ),
            )
        completion = self._chat_service.chat(messages=all_messages)
        full_response = completion.response
        if completion.sources:
            full_response += SOURCES_SEPARATOR
            curated_sources = Source.curate_sources(completion.sources)
            sources_text = "\n\n\n".join(
                f"{index}. {source.file} (page {source.page})"
                for index, source in enumerate(curated_sources, start=1)
            )
            full_response += sources_text
        return full_response

    # Used on initialization to set the system prompt
    # to the default prompt based on settings.
    @staticmethod
    def _get_default_system_prompt() -> str:
        return settings.DEFAULT_QUERY_SYSTEM_PROMPT

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        logger.info(f"Setting system prompt to: {system_prompt_input}")
        self._system_prompt = system_prompt_input
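
    # Collect the unique file names of all ingested documents, shaped as
    # single-column rows for the "Ingested Files" list component.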
    def _list_ingested_files(self) -> list[list[str]]:
        files = set()
        for ingested_document in self._ingest_service.list_ingested():
            if ingested_document.doc_metadata is None:
                # Skipping documents without metadata
                continue
            file_name = ingested_document.doc_metadata.get(
                "file_name", "[FILE NAME MISSING]"
            )
            files.add(file_name)
        return [[row] for row in files]
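
    # Upload callback: ingest the selected files as (file name, path) pairs.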
    def _upload_file(self, files: list[str]) -> None:
        logger.debug("Loading count=%s files", len(files))
        paths = [Path(file) for file in files]
        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])
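
    # Build the Gradio Blocks layout: a file-management column on the left
    # and the chat interface on the right.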
    def _build_ui_blocks(self) -> gr.Blocks:
        logger.debug("Creating the UI blocks")
        with gr.Blocks(
            title=UI_TAB_TITLE,
            theme=gr.themes.Soft(primary_hue=slate),
            css=".logo { "
            "display:flex;"
            "height: 80px;"
            "border-radius: 8px;"
            "align-content: center;"
            "justify-content: center;"
            "align-items: center;"
            "}"
            ".logo img { height: 25% }"
            ".contain { display: flex !important; flex-direction: column !important; }"
            "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
            "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
            "#col { height: calc(100vh - 112px - 16px) !important; }",
        ) as blocks:
            with gr.Row():
                gr.HTML(f"<div class='logo'><h1>{UI_TAB_TITLE}</h1></div>")

            with gr.Row(equal_height=False):
                with gr.Column(scale=3):
                    upload_button = gr.components.UploadButton(
                        "Upload File(s)",
                        type="filepath",
                        file_count="multiple",
                        size="sm",
                    )
                    ingested_dataset = gr.List(
                        self._list_ingested_files,
                        headers=["File name"],
                        label="Ingested Files",
                        interactive=False,
                        render=False,  # Rendered under the button
                    )
                    upload_button.upload(
                        self._upload_file,
                        inputs=upload_button,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.change(
                        self._list_ingested_files,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.render()
                    system_prompt_input = gr.Textbox(
                        placeholder=self._system_prompt,
                        label="System Prompt",
                        lines=2,
                        interactive=True,
                        render=False,
                    )
                    # On blur, set system prompt to use in queries
                    system_prompt_input.blur(
                        self._set_system_prompt,
                        inputs=system_prompt_input,
                    )

                with gr.Column(scale=7, elem_id="col"):
                    _ = gr.ChatInterface(
                        self._chat,
                        chatbot=gr.Chatbot(
                            label=f"LLM: {settings.LLM_MODE}",
                            show_copy_button=True,
                            elem_id="chatbot",
                            render=False,
                            avatar_images=(
                                None,
                                AVATAR_BOT,
                            ),
                        ),
                        additional_inputs=[upload_button, system_prompt_input],
                    )
        return blocks
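
    # Build the UI lazily and cache it, so repeated calls reuse the same Blocks.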
    def get_ui_blocks(self) -> gr.Blocks:
        if self._ui_block is None:
            self._ui_block = self._build_ui_blocks()
        return self._ui_block
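
    # Mount the Gradio UI inside an existing FastAPI application at the given path.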
    def mount_in_app(self, app: FastAPI, path: str) -> None:
        blocks = self.get_ui_blocks()
        blocks.queue()
        logger.info("Mounting the gradio UI, at path=%s", path)
        gr.mount_gradio_app(app, blocks, path=path)
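

# Local entry point: wire the shared components into the ingest and chat
# services, then launch the UI standalone instead of mounting it in FastAPI.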
if __name__ == "__main__":
llm_component = LLMComponent()
vector_store_component = VectorStoreComponent()
embedding_component = EmbeddingComponent()
node_store_component = NodeStoreComponent()
ingest_service = IngestService(
llm_component, vector_store_component, embedding_component, node_store_component
)
chat_service = ChatService(
llm_component, vector_store_component, embedding_component, node_store_component
)
ui = PrivateGptUi(ingest_service, chat_service)
_blocks = ui.get_ui_blocks()
_blocks.queue()
_blocks.launch(debug=False, show_api=False)