import torch
import streamlit as st
import time

from app_lib.user_input import (
    get_cardinality,
    get_class_name,
    get_concepts,
    get_image,
    get_model_name,
)
from app_lib.test import test


def _disable():
    """Button callback: set the session-wide ``disabled`` flag.

    The "Change Image" and "Test" buttons in ``main`` read this flag to decide
    whether they are disabled.
    """
    st.session_state.disabled = True


def main(device=None, *, dataset="imagenette"):
    """Render the testing page.

    The left column collects the model, image, class name, concepts and
    cardinality, and shows the "Change Image" / "Test" buttons. The right
    column shows the output of ``test`` once the "Test" button is pressed.

    Args:
        device: torch device forwarded to ``test``. When ``None`` (default),
            CUDA is used if available, otherwise CPU. This is resolved at call
            time rather than at import time.
        dataset: dataset name forwarded to ``test``; defaults to
            ``"imagenette"``, the value the page previously hard-coded.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Both buttons below read this flag; make sure it exists on a fresh
    # session so the first render does not fail if nothing has set it yet.
    if "disabled" not in st.session_state:
        st.session_state.disabled = False

    columns = st.columns([0.40, 0.60])

    with columns[0]:
        model_name = get_model_name()

        row1 = st.columns(2)
        row2 = st.columns(2)

        with row1[0]:
            image = get_image()
            st.image(image, use_column_width=True)

        with row1[1]:
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()
            cardinality = get_cardinality(concepts, concepts_ready)

        with row2[0]:
            change_image_button = st.button(
                "Change Image",
                use_container_width=True,
                disabled=st.session_state.disabled,
            )
            if change_image_button:
                # NOTE(review): "sidebar_state" is presumably read by the page
                # config elsewhere to re-open the sidebar — confirm.
                st.session_state.sidebar_state = "expanded"
                st.experimental_rerun()

        with row2[1]:
            ready = class_ready and concepts_ready

            # Show every input error as a markdown bullet list.
            errors = [e for e in (class_error, concepts_error) if e is not None]
            if errors:
                st.error("".join(f"- {e}\n" for e in errors))

            test_button = st.button(
                "Test",
                use_container_width=True,
                # _disable runs before the rerun, so widgets render disabled
                # while the test executes.
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    with columns[1]:
        # Center the test output within the right-hand column.
        _, centercol, _ = st.columns(3)
        with centercol:
            if test_button:
                test(
                    image,
                    class_name,
                    concepts,
                    cardinality,
                    dataset,
                    model_name,
                    device,
                )