Spaces:
Runtime error
Runtime error
""" | |
This file defines LLM4SciLitChatInterface, a customized two-column variant of Gradio's high-level ChatInterface abstraction for building chatbots.
""" | |
from __future__ import annotations | |
import inspect | |
from typing import AsyncGenerator, Callable | |
import anyio | |
from gradio_client import utils as client_utils | |
from gradio_client.documentation import document, set_documentation_group | |
from gradio.blocks import Blocks | |
from gradio.components import ( | |
Button, | |
Chatbot, | |
IOComponent, | |
Markdown, | |
State, | |
Textbox, | |
get_component_instance, | |
) | |
from gradio.events import Dependency, EventListenerMethod, on | |
from gradio.helpers import create_examples as Examples # noqa: N812 | |
from gradio.layouts import Accordion, Column, Group, Row | |
from gradio.themes import ThemeClass as Theme | |
from gradio.utils import SyncToAsyncIterator, async_iteration | |
# NOTE(review): presumably files subsequently @document-decorated objects under the
# "chatinterface" group of gradio's generated docs; the class below is not
# decorated with @document, so this call may have no visible effect here.
set_documentation_group("chatinterface")
class LLM4SciLitChatInterface(Blocks):
    """
    Two-column chatbot UI adapted from Gradio's ChatInterface.

    The left column holds the chatbot, the message textbox and the control
    buttons. The right column holds read-only "Document N" textboxes. Besides
    the chat reply, ``fn`` may return (or yield) extra values that are shown in
    those Document textboxes (presumably the sources behind the answer).

    Example:
        def echo(message, history):
            return message, "doc 1", "doc 2", "doc 3", "doc 4"

        demo = LLM4SciLitChatInterface(
            fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot"
        )
        demo.launch()
    """

    def __init__(
        self,
        fn: Callable,
        *,
        chatbot: Chatbot | None = None,
        textbox: Textbox | None = None,
        additional_inputs: str | IOComponent | list[str | IOComponent] | None = None,
        additional_inputs_accordion_name: str = "Additional Inputs",
        examples: list[str] | None = None,
        cache_examples: bool | None = None,
        title: str | None = None,
        description: str | None = None,
        theme: Theme | str | None = None,
        css: str | None = None,
        analytics_enabled: bool | None = None,
        submit_btn: str | None | Button = "Submit",
        stop_btn: str | None | Button = "Stop",
        retry_btn: str | None | Button = "🔄 Retry",
        undo_btn: str | None | Button = "↩️ Undo",
        clear_btn: str | None | Button = "🗑️ Clear",
        autofocus: bool = True,
        num_documents: int = 4,
    ):
        """
        Parameters:
            fn: the function to wrap the chat interface around. It is called as
                ``fn(message, history, *additional_input_values)``, where history is
                a list of two-element lists ``[[user_message, bot_message], ...]``.
                It may return a response string, or a list/tuple
                ``(response, doc_1, ..., doc_n)`` whose extra items fill the Document
                textboxes (missing items are shown empty, surplus items are ignored).
                Generator functions may yield values of the same form to stream the reply.
            chatbot: an instance of gr.Chatbot to use; if not provided, a default one is created.
            textbox: an instance of gr.Textbox to use for the message input; if not provided, a default one is created.
            additional_inputs: an instance or list of gradio components (or their string shortcuts) passed to fn after the history. Components not already rendered are displayed in an accordion below the chat.
            additional_inputs_accordion_name: label of the accordion used for additional inputs.
            examples: sample messages; if provided, appear below the chat and can be clicked to populate the input.
            cache_examples: if True, caches examples on the server. If None, defaults to True on Hugging Face Spaces and False elsewhere.
            title: a title; if provided, appears above the chat in large font and is used as the browser tab title.
            description: a description shown beneath the title. Accepts Markdown and HTML.
            theme: theme to use, loaded from gradio.themes.
            css: custom css or path to a custom css file.
            analytics_enabled: whether to allow basic telemetry. If None, uses the GRADIO_ANALYTICS_ENABLED environment variable or defaults to True.
            submit_btn: text of the submit button, a gr.Button to use instead, or None for no button.
            stop_btn: text of the stop button shown while a generator fn is streaming, a gr.Button, or None to disable stopping.
            retry_btn: text of the retry button, a gr.Button, or None for no button.
            undo_btn: text of the undo (delete last turn) button, a gr.Button, or None for no button.
            clear_btn: text of the clear button, a gr.Button, or None for no button.
            autofocus: if True, autofocuses the default textbox when the page loads.
            num_documents: number of read-only Document textboxes in the right-hand column.
        """
        super().__init__(
            analytics_enabled=analytics_enabled,
            mode="chat_interface",
            css=css,
            title=title or "Gradio",
            theme=theme,
        )
        self.fn = fn
        self.is_async = inspect.iscoroutinefunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)
        self.is_generator = inspect.isgeneratorfunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)
        self.examples = examples
        if self.space_id and cache_examples is None:
            self.cache_examples = True
        else:
            self.cache_examples = cache_examples or False
        # Always 5 entries (submit, stop, retry, undo, clear); None marks a disabled button.
        self.buttons: list[Button | None] = []

        if additional_inputs:
            if not isinstance(additional_inputs, list):
                additional_inputs = [additional_inputs]
            self.additional_inputs = [
                get_component_instance(i) for i in additional_inputs  # type: ignore
            ]
        else:
            self.additional_inputs = []
        self.additional_inputs_accordion_name = additional_inputs_accordion_name
        # Document textboxes filled from the extra values returned by fn.
        self.additional_outputs: list[Textbox] = []

        with self:
            if title:
                Markdown(
                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
                )
            if description:
                Markdown(description)

            with Row():
                with Column(variant="panel", scale=1):
                    if chatbot:
                        self.chatbot = chatbot.render()
                    else:
                        self.chatbot = Chatbot(label="Chatbot")

                    with Group():
                        with Row():
                            if textbox:
                                textbox.container = False
                                textbox.show_label = False
                                self.textbox = textbox.render()
                            else:
                                self.textbox = Textbox(
                                    container=False,
                                    show_label=False,
                                    label="Message",
                                    placeholder="Type a message...",
                                    scale=7,
                                    autofocus=autofocus,
                                )
                            if submit_btn:
                                if isinstance(submit_btn, Button):
                                    submit_btn.render()
                                elif isinstance(submit_btn, str):
                                    submit_btn = Button(
                                        submit_btn,
                                        variant="primary",
                                        scale=1,
                                        min_width=150,
                                    )
                                else:
                                    raise ValueError(
                                        f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
                                    )
                            if stop_btn:
                                if isinstance(stop_btn, Button):
                                    stop_btn.visible = False
                                    stop_btn.render()
                                elif isinstance(stop_btn, str):
                                    stop_btn = Button(
                                        stop_btn,
                                        variant="stop",
                                        visible=False,
                                        scale=1,
                                        min_width=150,
                                    )
                                else:
                                    raise ValueError(
                                        f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
                                    )
                            self.buttons.extend([submit_btn, stop_btn])

                    with Row():
                        for btn in [retry_btn, undo_btn, clear_btn]:
                            if btn:
                                if isinstance(btn, Button):
                                    btn.render()
                                elif isinstance(btn, str):
                                    btn = Button(btn, variant="secondary")
                                else:
                                    raise ValueError(
                                        f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
                                    )
                            # Appended even when None so the unpacking below always sees 5 slots.
                            self.buttons.append(btn)

                    # Hidden components that back the "chat" API endpoint (see _setup_api).
                    self.fake_api_btn = Button("Fake API", visible=False)
                    self.fake_response_textbox = Textbox(
                        label="Response", visible=False
                    )
                    (
                        self.submit_btn,
                        self.stop_btn,
                        self.retry_btn,
                        self.undo_btn,
                        self.clear_btn,
                    ) = self.buttons

                with Column(variant="panel", scale=2):
                    for i in range(num_documents):
                        self.additional_outputs.append(
                            Textbox(interactive=False, label=f"Document {i+1}")
                        )

            if examples:
                if self.is_generator:
                    examples_fn = self._examples_stream_fn
                else:
                    examples_fn = self._examples_fn

                self.examples_handler = Examples(
                    examples=examples,
                    inputs=[self.textbox] + self.additional_inputs,
                    outputs=self.chatbot,
                    fn=examples_fn,
                )

            any_unrendered_inputs = any(
                not inp.is_rendered for inp in self.additional_inputs
            )
            if self.additional_inputs and any_unrendered_inputs:
                with Accordion(self.additional_inputs_accordion_name, open=False):
                    for input_component in self.additional_inputs:
                        if not input_component.is_rendered:
                            input_component.render()

            # The example caching must happen after the input components have rendered.
            # Uses self.cache_examples so the documented Spaces default is honored, and
            # requires examples because examples_handler only exists when they were given.
            if self.cache_examples and examples:
                client_utils.synchronize_async(self.examples_handler.cache)

            self.saved_input = State()
            self.chatbot_state = State([])

            self._setup_events()
            self._setup_api()

    def _setup_events(self) -> None:
        """Wire the submit, retry, undo and clear interactions."""
        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
        submit_triggers = (
            [self.textbox.submit, self.submit_btn.click]
            if self.submit_btn
            else [self.textbox.submit]
        )
        # Submit and retry both call fn, so both must update the Document textboxes
        # (submit_fn returns one value per component in this list).
        model_outputs = [self.chatbot, self.chatbot_state] + self.additional_outputs

        submit_event = (
            on(
                submit_triggers,
                self._clear_and_save_textbox,
                [self.textbox],
                [self.textbox, self.saved_input],
                api_name=False,
                queue=False,
            )
            .then(
                self._display_input,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.chatbot_state],
                api_name=False,
                queue=False,
            )
            .then(
                submit_fn,
                [self.saved_input, self.chatbot_state] + self.additional_inputs,
                model_outputs,
                api_name=False,
            )
        )
        self._setup_stop_events(submit_triggers, submit_event)

        if self.retry_btn:
            retry_event = (
                self.retry_btn.click(
                    self._delete_prev_fn,
                    [self.chatbot_state],
                    [self.chatbot, self.saved_input, self.chatbot_state],
                    api_name=False,
                    queue=False,
                )
                .then(
                    self._display_input,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.chatbot_state],
                    api_name=False,
                    queue=False,
                )
                .then(
                    submit_fn,
                    [self.saved_input, self.chatbot_state] + self.additional_inputs,
                    model_outputs,
                    api_name=False,
                )
            )
            self._setup_stop_events([self.retry_btn.click], retry_event)

        if self.undo_btn:
            self.undo_btn.click(
                self._delete_prev_fn,
                [self.chatbot_state],
                [self.chatbot, self.saved_input, self.chatbot_state],
                api_name=False,
                queue=False,
            ).then(
                lambda x: x,
                [self.saved_input],
                [self.textbox],
                api_name=False,
                queue=False,
            )

        if self.clear_btn:
            num_documents = len(self.additional_outputs)
            # Also blank the Document textboxes so no stale content outlives the chat.
            self.clear_btn.click(
                lambda: ([], [], None, *([None] * num_documents)),
                None,
                [self.chatbot, self.chatbot_state, self.saved_input]
                + self.additional_outputs,
                queue=False,
                api_name=False,
            )

    def _setup_stop_events(
        self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
    ) -> None:
        """Show the stop button while event_to_cancel streams, and let it cancel it.

        Only active when a stop button exists and fn is a generator.
        """
        if self.stop_btn and self.is_generator:
            if self.submit_btn:
                for event_trigger in event_triggers:
                    event_trigger(
                        lambda: (
                            Button.update(visible=False),
                            Button.update(visible=True),
                        ),
                        None,
                        [self.submit_btn, self.stop_btn],
                        api_name=False,
                        queue=False,
                    )
                event_to_cancel.then(
                    lambda: (Button.update(visible=True), Button.update(visible=False)),
                    None,
                    [self.submit_btn, self.stop_btn],
                    api_name=False,
                    queue=False,
                )
            else:
                for event_trigger in event_triggers:
                    event_trigger(
                        lambda: Button.update(visible=True),
                        None,
                        [self.stop_btn],
                        api_name=False,
                        queue=False,
                    )
                event_to_cancel.then(
                    lambda: Button.update(visible=False),
                    None,
                    [self.stop_btn],
                    api_name=False,
                    queue=False,
                )
            self.stop_btn.click(
                None,
                None,
                None,
                cancels=event_to_cancel,
                api_name=False,
            )

    def _setup_api(self) -> None:
        """Expose fn as the "chat" API endpoint through the hidden fake_api_btn."""
        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn

        self.fake_api_btn.click(
            api_fn,
            [self.textbox, self.chatbot_state] + self.additional_inputs,
            [self.textbox, self.chatbot_state],
            api_name="chat",
        )

    def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:
        """Empty the textbox and stash its content in saved_input."""
        return "", message

    def _display_input(
        self, message: str, history: list[list[str | None]]
    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
        """Append the pending user message (no reply yet) to history, in place."""
        history.append([message, None])
        return history, history

    def _split_fn_output(self, output: object) -> tuple[object, list[object]]:
        """Split one value returned/yielded by fn into (response, document_values).

        A list/tuple is read as ``(response, doc_1, ..., doc_n)``; anything else is
        the bare response. The document values are padded with None (cleared box)
        or truncated so their count equals ``len(self.additional_outputs)``,
        keeping the number of returned values equal to the number of outputs.
        """
        if isinstance(output, (list, tuple)):
            response = output[0] if output else None
            extras = list(output[1:])
        else:
            response, extras = output, []
        num_documents = len(self.additional_outputs)
        extras = (extras + [None] * num_documents)[:num_documents]
        return response, extras

    async def _submit_fn(
        self,
        message: str,
        history_with_input: list[list[str | None]],
        *args,
    ) -> tuple:
        """Run a non-streaming fn.

        Returns (history, history, *document_values).
        """
        # Drop the pending [message, None] row added by _display_input.
        history = history_with_input[:-1]
        if self.is_async:
            output = await self.fn(message, history, *args)
        else:
            output = await anyio.to_thread.run_sync(
                self.fn, message, history, *args, limiter=self.limiter
            )
        response, extras = self._split_fn_output(output)
        history.append([message, response])
        return history, history, *extras

    async def _stream_fn(
        self,
        message: str,
        history_with_input: list[list[str | None]],
        *args,
    ) -> AsyncGenerator:
        """Run a generator fn, yielding (history, history, *document_values) per step."""
        history = history_with_input[:-1]
        if self.is_async:
            generator = self.fn(message, history, *args)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, message, history, *args, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)
        try:
            first_output = await async_iteration(generator)
        except (StopIteration, StopAsyncIteration):
            # Async iterators signal exhaustion with StopAsyncIteration. An empty
            # generator still produces one update showing the message with no reply.
            first_output = None
        response, extras = self._split_fn_output(first_output)
        update = history + [[message, response]]
        yield update, update, *extras
        async for output in generator:
            response, extras = self._split_fn_output(output)
            update = history + [[message, response]]
            yield update, update, *extras

    async def _api_submit_fn(
        self, message: str, history: list[list[str | None]], *args
    ) -> tuple[str, list[list[str | None]]]:
        """API variant of _submit_fn; returns (response, history) without documents."""
        if self.is_async:
            output = await self.fn(message, history, *args)
        else:
            output = await anyio.to_thread.run_sync(
                self.fn, message, history, *args, limiter=self.limiter
            )
        response, _ = self._split_fn_output(output)
        history.append([message, response])
        return response, history

    async def _api_stream_fn(
        self, message: str, history: list[list[str | None]], *args
    ) -> AsyncGenerator:
        """API variant of _stream_fn; yields (response, history) without documents."""
        if self.is_async:
            generator = self.fn(message, history, *args)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, message, history, *args, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)
        try:
            first_output = await async_iteration(generator)
        except (StopIteration, StopAsyncIteration):
            first_output = None
        first_response, _ = self._split_fn_output(first_output)
        yield first_response, history + [[message, first_response]]
        async for output in generator:
            response, _ = self._split_fn_output(output)
            yield response, history + [[message, response]]

    async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
        """Produce the chatbot value for an example message (empty prior history)."""
        if self.is_async:
            output = await self.fn(message, [], *args)
        else:
            output = await anyio.to_thread.run_sync(
                self.fn, message, [], *args, limiter=self.limiter
            )
        response, _ = self._split_fn_output(output)
        return [[message, response]]

    async def _examples_stream_fn(
        self,
        message: str,
        *args,
    ) -> AsyncGenerator:
        """Streaming variant of _examples_fn."""
        if self.is_async:
            generator = self.fn(message, [], *args)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, message, [], *args, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)
        async for output in generator:
            response, _ = self._split_fn_output(output)
            yield [[message, response]]

    def _delete_prev_fn(
        self, history: list[list[str | None]]
    ) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
        """Pop the last turn from history (in place).

        Returns (history, popped_user_message or "", history).
        """
        try:
            message, _ = history.pop()
        except IndexError:
            message = ""
        return history, message or "", history