|
from typing import Literal |
|
|
|
from app_configs import CONFIGS, PLAYGROUND_MODELS, SUBMISSION_MODELS |
|
from display.formatting import styled_error |
|
from shared.workflows.errors import ProviderAPIError, WorkflowExecutionError |
|
from shared.workflows.structs import TossupWorkflow, Workflow |
|
from shared.workflows.validators import ValidationErrorType, WorkflowValidationError, WorkflowValidator |
|
|
|
|
|
class UnsupportedModelError(Exception):
    """Exception for unsupported model errors.

    Raised when a workflow step references a model that is not in the allowed
    model list for the current validation context.

    Attributes:
        model_name: Name of the rejected model, as passed by the caller.
        playground: True when the check was made against ``PLAYGROUND_MODELS``;
            False when it was made against ``SUBMISSION_MODELS``. This flag
            selects which message ``__str__`` produces.
    """

    def __init__(self, model_name: str, playground: bool = False):
        # Forward the arguments to Exception so ``self.args`` is populated. Without
        # this, args == () and both pickle and copy.copy (which rebuild exceptions
        # via ``cls(*args)``) fail because ``model_name`` is a required argument.
        super().__init__(model_name, playground)
        self.playground = playground
        self.model_name = model_name

    def __str__(self):
        if self.playground:
            return f"Model '{self.model_name}' is not supported in playground mode. Please use one of the supported playground models: {PLAYGROUND_MODELS}.\n But, you can use the full suite for submission."
        return (
            f"Model '{self.model_name}' is not supported. Please use one of the supported models: {SUBMISSION_MODELS}."
        )
|
|
|
|
|
def create_error_message(e: Exception) -> str:
    """Turn an exception into a styled, user-facing error message.

    Exception types are checked in this order:

    - ``UnsupportedModelError``: the exception's own text.
    - ``ProviderAPIError``: a provider-outage notice naming ``e.provider``.
    - ``WorkflowExecutionError``: a workflow-failure notice including ``e``.
    - ``ValueError``: an invalid-input notice including ``e``.

    Any other exception yields a generic message that does not include the
    exception text.

    Args:
        e: The exception to describe.

    Returns:
        The chosen message passed through ``styled_error``.
    """
    if isinstance(e, UnsupportedModelError):
        text = str(e)
    elif isinstance(e, ProviderAPIError):
        text = f"Our {e.provider} models are currently experiencing issues. Please try again later. \n\nIf the problem persists, please contact support."
    elif isinstance(e, WorkflowExecutionError):
        text = f"Workflow execution failed: {e}. Please try again later. \n\nIf the problem persists, please contact support."
    elif isinstance(e, ValueError):
        text = f"Invalid input -- {e}. Please try again. \n\nIf the problem persists, please contact support."
    else:
        text = "An unexpected error occurred. Please contact support."
    # Every branch is styled the same way, so wrap once at the end.
    return styled_error(text)
|
|
|
|
|
class InterfaceWorkflowValidator:
    """Validate a workflow against the interface required by a question mode.

    Validation has three stages, run in order by ``__call__``:

    1. every variable in ``CONFIGS[mode]["required_input_vars"]`` must appear in
       ``workflow.inputs``;
    2. every variable in ``CONFIGS[mode]["required_output_vars"]`` must appear in
       ``workflow.outputs``;
    3. the workflow must pass ``WorkflowValidator`` using the playground or
       submission model allow-list.

    Failures are reported as ``ValueError``. Disallowed models are reported as
    ``UnsupportedModelError``.
    """

    def __init__(self, mode: Literal["tossup", "bonus"], editing: bool = False):
        """
        Args:
            mode: Key into ``CONFIGS`` selecting the required input/output variables.
            editing: Forwarded to ``WorkflowValidator.validate`` as ``allow_empty``.
        """
        self.mode = mode
        self.required_input_vars = CONFIGS[mode]["required_input_vars"]
        self.required_output_vars = CONFIGS[mode]["required_output_vars"]
        self.editing = editing

    def _check_required_inputs(self, workflow) -> None:
        """Raise ValueError naming the first required input missing from ``workflow.inputs``."""
        input_vars = set(workflow.inputs)
        for req_var in self.required_input_vars:
            if req_var not in input_vars:
                default_str = "inputs:\n" + "\n".join([f"- {var}" for var in self.required_input_vars])
                raise ValueError(
                    f"Missing required input variable: '{req_var}'. "
                    "\nDon't modify the 'inputs' field in the workflow. "
                    "Please set it back to:"
                    f"\n{default_str}"
                )

    def _check_required_outputs(self, workflow) -> None:
        """Raise ValueError naming the first required output missing from ``workflow.outputs``."""
        output_vars = set(workflow.outputs)
        for req_var in self.required_output_vars:
            if req_var not in output_vars:
                default_str = "[" + ", ".join([f"'{var}'" for var in self.required_output_vars]) + "]"
                raise ValueError(
                    f"Missing required output variable: '{req_var}'. "
                    "\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
                    f"\nMake sure you have values set for all the outputs: {default_str}"
                )

    def __call__(self, workflow: "TossupWorkflow | Workflow", playground: bool = False):
        """Validate ``workflow``; return None on success.

        Args:
            workflow: The workflow to check.
            playground: Selects ``PLAYGROUND_MODELS`` (True) or ``SUBMISSION_MODELS``
                (False) as the allowed model list.

        Raises:
            ValueError: A required input/output variable is missing, or
                ``WorkflowValidator`` reported errors other than an unsupported
                model in first position.
            UnsupportedModelError: The first validation error is of type
                ``ValidationErrorType.UNSUPPORTED_MODEL``.
        """
        self._check_required_inputs(workflow)
        self._check_required_outputs(workflow)

        allowed_models = PLAYGROUND_MODELS if playground else SUBMISSION_MODELS
        # Validate through a local rather than ``self.validator``. If this instance
        # is shared across concurrent calls, another call with a different
        # ``playground`` flag could otherwise replace the allow-list mid-validation.
        # The attribute is still set for any caller that reads it.
        validator = WorkflowValidator(allowed_model_names=allowed_models)
        self.validator = validator
        try:
            validator.validate(workflow, allow_empty=self.editing)
        except WorkflowValidationError as e:
            # Only the first error is inspected for the unsupported-model case.
            if e.errors and e.errors[0].error_type == ValidationErrorType.UNSUPPORTED_MODEL:
                step_id = e.errors[0].step_id
                model_name = workflow.steps[step_id].get_full_model_name()
                raise UnsupportedModelError(model_name, playground=playground) from e
            error_msg_total = f"Found {len(e.errors)} errors in the workflow:\n"
            error_msg_list = [f"- {err.message}" for err in e.errors]
            error_msg = error_msg_total + "\n".join(error_msg_list)
            raise ValueError(error_msg) from e

    def validate_state_dict(self, state_dict: dict, playground: bool = False):
        """Validate a state dictionary.

        Builds a ``TossupWorkflow`` (tossup mode) or ``Workflow`` (any other mode)
        from ``state_dict["workflow"]``, validates it with ``self(...)``, and
        returns the constructed workflow.

        Raises:
            KeyError: ``state_dict`` has no ``"workflow"`` entry.
            ValueError, UnsupportedModelError: As raised by ``__call__``.
        """
        if self.mode == "tossup":
            workflow = TossupWorkflow(**state_dict["workflow"])
        else:
            workflow = Workflow(**state_dict["workflow"])
        self(workflow, playground=playground)
        return workflow
|
|