Maharshi Gor
Restructure output panel rendering. This change fixes logprob issue and
1758388
raw
history blame
8.4 kB
import gradio as gr
import numpy as np
from loguru import logger
from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
from components import commons
from components.structs import TossupPipelineState
from components.typed_dicts import TossupPipelineStateDict
from display.formatting import tiny_styled_warning
from workflows.structs import Buzzer, TossupWorkflow
from .model_pipeline import PipelineInterface, PipelineState, PipelineUIState
def toggleable_slider(
    value, minimum, maximum, step, toggle_value=False, label=None, info=None, min_width=200, scale=1
):
    """Render a checkbox that enables/disables an accompanying slider.

    Args:
        value: Initial slider value.
        minimum: Slider minimum.
        maximum: Slider maximum.
        step: Slider step size.
        toggle_value: Initial checkbox state; also the slider's initial
            interactivity.
        label: Checkbox label; when None the label is hidden.
        info: Optional info text shown with the checkbox.
        min_width: Minimum width of the wrapping column.
        scale: Relative scale of the wrapping column.

    Returns:
        A ``(checkbox, slider)`` tuple of the created components.
    """
    with gr.Column(elem_classes="toggleable", min_width=min_width, scale=scale):
        show_label = label is not None
        checkbox = gr.Checkbox(label=label, value=toggle_value, container=False, info=info, show_label=show_label)
        slider = gr.Slider(
            minimum=minimum,
            maximum=maximum,
            value=value,
            step=step,
            label="",
            # Start in sync with the checkbox: the change listener below only
            # fires on subsequent toggles, so an unchecked box must not leave
            # the slider editable on first render.
            interactive=toggle_value,
            show_label=False,
            container=False,
        )
        checkbox.change(fn=lambda x: gr.update(interactive=x), inputs=[checkbox], outputs=[slider])
    return checkbox, slider
class TossupPipelineInterface(PipelineInterface):
    """Pipeline editor UI for tossup workflows.

    Extends ``PipelineInterface`` with a buzzer settings panel (confidence
    threshold, AND/OR method, token-probability threshold) rendered below the
    final output-variable mapping.
    """

    def __init__(
        self,
        app: gr.Blocks,
        workflow: TossupWorkflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] | None = None,
        config: dict | None = None,
    ):
        """Build the interface and keep ``buzzer_state`` in sync with the pipeline.

        Args:
            app: The enclosing Gradio Blocks app.
            workflow: Initial tossup workflow to edit.
            ui_state: Optional initial UI state, forwarded to ``PipelineInterface``.
            model_options: Optional model names, forwarded to ``PipelineInterface``.
            config: Optional config dict, forwarded to ``PipelineInterface``.
                Defaults to a fresh empty dict for each instance.
        """
        # Use a fresh dict per instance; a literal ``{}`` default would be one
        # object shared by every instance built without an explicit config.
        if config is None:
            config = {}
        super().__init__(app, workflow, ui_state, model_options, config)
        # Serialized buzzer settings, re-derived from the pipeline state on every change.
        self.buzzer_state = gr.State(workflow.buzzer.model_dump())
        self.pipeline_state.change(
            lambda x: TossupPipelineState(**x).workflow.buzzer.model_dump(),
            inputs=[self.pipeline_state],
            outputs=[self.buzzer_state],
        )

    def update_prob_slider(
        self, state_dict: TossupPipelineStateDict, answer_var: str, tokens_prob: float | None
    ) -> tuple[TossupPipelineStateDict, dict, dict, dict]:
        """Update the probability slider based on the answer variable.

        Args:
            state_dict: Serialized tossup pipeline state.
            answer_var: Workflow variable currently mapped to the answer output.
            tokens_prob: Current value of the probability slider.

        Returns:
            A 4-tuple of (updated state dict, probability-slider update,
            method-dropdown update, warning-HTML update).

            If ``answer_var`` is the unselected sentinel, the state is returned
            unchanged, both controls are made interactive, the method shows
            ``"AND"``, and the warning is hidden.

            Otherwise the buzzer is rebuilt. If the answer model does not
            support token probabilities, the probability threshold is cleared,
            the method is forced to ``"AND"``, the controls are disabled, and
            the warning is shown.
        """
        state = TossupPipelineState(**state_dict)
        if answer_var == UNSELECTED_VAR_NAME:
            return (
                state.model_dump(),
                gr.update(interactive=True),
                gr.update(value="AND", interactive=True),
                gr.update(visible=False),
            )
        logprobs_supported = state.workflow.is_token_probs_supported(answer_var)
        buzzer = state.workflow.buzzer
        # Without logprobs the probability threshold is meaningless, so drop it
        # and fall back to AND (confidence-only) buzzing.
        tokens_prob_threshold = tokens_prob if logprobs_supported else None
        method = buzzer.method if logprobs_supported else "AND"
        state.workflow.buzzer = Buzzer(
            method=method,
            confidence_threshold=buzzer.confidence_threshold,
            prob_threshold=tokens_prob_threshold,
        )
        model_name = state.workflow.get_answer_model(answer_var)
        return (
            state.model_dump(),
            gr.update(interactive=logprobs_supported, value=tokens_prob if logprobs_supported else 0.0),
            gr.update(value=method, interactive=logprobs_supported),
            gr.update(
                value=tiny_styled_warning(
                    f"<code>'{model_name}'</code> does not support <code>'logprobs'</code>. The probability slider will be disabled."
                ),
                visible=not logprobs_supported,
            ),
        )

    def _render_buzzer_panel(
        self, buzzer: Buzzer, prob_slider_supported: bool, selected_model_name: str | None = None
    ):
        """Render the buzzer controls and store them on ``self``.

        Creates ``confidence_slider``, ``buzzer_method_dropdown``,
        ``prob_slider``, and ``buzzer_warning_display``.

        Args:
            buzzer: Buzzer settings that supply the initial widget values.
            prob_slider_supported: Whether the answer model supports token
                probabilities. When False, the method is shown as ``"AND"``,
                the method dropdown and probability slider are non-interactive,
                and the warning is visible.
            selected_model_name: Model named in the warning text. When None,
                the warning HTML is empty.
        """
        with gr.Row(elem_classes="control-panel"):
            self.confidence_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=buzzer.confidence_threshold,
                step=0.01,
                label="Confidence",
                elem_classes="slider-container",
                show_reset_button=False,
            )
            value = buzzer.method if prob_slider_supported else "AND"
            self.buzzer_method_dropdown = gr.Dropdown(
                choices=["AND", "OR"],
                value=value,
                label="Method",
                interactive=prob_slider_supported,
                min_width=80,
                scale=0,
            )
            self.prob_slider = gr.Slider(
                # A missing (None) threshold is displayed as 0.0.
                value=buzzer.prob_threshold or 0.0,
                interactive=prob_slider_supported,
                label="Probability",
                minimum=0.0,
                maximum=1.0,
                step=0.001,
                elem_classes="slider-container",
                show_reset_button=False,
            )
        display_html = ""
        if selected_model_name is not None:
            display_html = tiny_styled_warning(
                f"<code>{selected_model_name}</code> does not support <code>logprobs</code>. The probability slider will be disabled."
            )
        self.buzzer_warning_display = gr.HTML(display_html, visible=not prob_slider_supported)

    def _output_variable_choices(self, available_variables: list[str]) -> list[str]:
        """Return output-dropdown choices.

        The list is the unselected sentinel first, then every available
        variable that is not a pipeline input. The sentinel is not duplicated.
        """
        return [UNSELECTED_VAR_NAME] + [
            v for v in available_variables if v not in self.input_variables and v != UNSELECTED_VAR_NAME
        ]

    def _render_output_panel(self, pipeline_state: TossupPipelineState):
        """Render the output-variable mapping and the buzzer settings panel.

        One dropdown is created per entry of ``self.required_output_variables``.
        Each dropdown's change event writes the selection back into the pipeline
        state via ``self.sm.update_output_variables``. The buzzer panel is
        rendered from ``pipeline_state.workflow.buzzer``, and its inputs are
        wired to ``self.sm.update_buzzer``.

        Args:
            pipeline_state: Current pipeline state used to populate the widgets.
        """
        dropdowns = {}
        variable_options = self._output_variable_choices(pipeline_state.workflow.get_available_variables())
        with gr.Column(elem_classes="step-accordion control-panel"):
            commons.get_panel_header(
                header="Final output variables mapping:",
            )
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs.get(output_field)
                    value = value or UNSELECTED_VAR_NAME
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown
            commons.get_panel_header(
                header="Buzzer settings:",
                subheader="Set your thresholds for confidence and output tokens probability (computed using <code>logprobs</code>).",
            )
            logprobs_supported = pipeline_state.workflow.is_token_probs_supported()
            selected_model_name = pipeline_state.workflow.get_answer_model()
            self._render_buzzer_panel(pipeline_state.workflow.buzzer, logprobs_supported, selected_model_name)

        def update_choices(available_variables: list[str]):
            """Refresh every output dropdown's choices when the variable set changes."""
            # Build choices exactly as on the initial render, so the unselected
            # sentinel stays selectable and input variables stay hidden.
            # ``value=None`` clears the current selection, as before.
            choices = self._output_variable_choices(available_variables)
            return [gr.update(choices=choices, value=None) for _ in dropdowns]

        self.variables_state.change(
            update_choices,
            inputs=[self.variables_state],
            outputs=list(dropdowns.values()),
        )

        # Updating the pipeline buzzer on user input changes in Buzzer panel.
        gr.on(
            triggers=[
                self.confidence_slider.release,
                self.buzzer_method_dropdown.input,
                self.prob_slider.release,
            ],
            fn=self.sm.update_buzzer,
            inputs=[self.pipeline_state, self.confidence_slider, self.buzzer_method_dropdown, self.prob_slider],
            outputs=[self.pipeline_state],
        )
        # NOTE: no listener syncs the buzzer widgets with ``self.buzzer_state``
        # or the answer dropdown (e.g. via ``update_prob_slider``). The output
        # panel is re-rendered entirely on changes, so the widgets are rebuilt
        # with current values.