# Page header residue from the Spaces file viewer this file was copied from:
# status "Runtime error"; file size 1,897 bytes; commit ab87be2.
import streamlit as st
from src.architectures import *
from src.common import generate_group_tag
from src.testing import TestGenerator
from src.st_helpers import st_setup
# The registry lives on the Architecture class, so it survives Streamlit's
# script re-executions within the same process; only populate it when it has
# not been loaded yet.
if Architecture.architectures is None:
    Architecture.load_architectures()
# Test Runner page: choose architectures and a question count, then run every
# selected architecture against the same set of questions, tagging each traced
# request so the batch can be found later.
# NOTE(review): st_setup presumably returns False when the page should not be
# rendered (e.g. setup/auth failed) - confirm against src.st_helpers.
if st_setup('LLM Arch'):
    summary = st.container()
    with summary:
        st.write("# Test Runner")
        st.write("## Run a new test")

        st.write("### Comment:")
        comment = st.text_input("Optional comment for the test")

        st.write("### Architectures to include:")
        selected_archs = st.multiselect(
            label="Architectures",
            options=[a.name for a in Architecture.architectures],
        )

        st.write("### Number of questions to ask:")
        q_count = st.slider(
            label="Number of questions",
            min_value=1,
            max_value=TestGenerator.question_count(),
            step=1,
        )

        st.write("### Tag:")
        # Streamlit re-executes this whole script on every widget interaction,
        # including the click on "Run". A tag generated inline would therefore
        # change between being shown here and being used for the run. Keep it
        # in session state so the tag the user records is the tag applied.
        if "test_runner_tag" not in st.session_state:
            st.session_state["test_runner_tag"] = generate_group_tag()
        tag = st.session_state["test_runner_tag"]
        st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')

        total_tests = len(selected_archs) * q_count

        st.write("### Run:")
        st.write(f"**{total_tests}** total tests will be run")

        if st.button("**Run**", disabled=(total_tests == 0)):
            progress = st.progress(0.0, text="Running tests...")
            # One question set is fetched and reused for every architecture,
            # so the selected architectures are compared on identical inputs.
            questions = TestGenerator.get_random_questions(q_count)
            # Base progress on what was actually returned rather than assuming
            # exactly q_count questions: this keeps the fraction within [0, 1].
            run_total = len(selected_archs) * len(questions)
            num_complete = 0
            for arch_name in selected_archs:
                architecture = Architecture.get_architecture(arch_name)
                for q in questions:
                    architecture(ArchitectureRequest(q), trace_tags=[tag, "TestRunner"], trace_comment=comment)
                    num_complete += 1
                    if num_complete < run_total:
                        progress.progress(num_complete / run_total, f"Run {num_complete} of {run_total} tests...")
            # Always clear the bar once the loop finishes, even if the
            # question count differed from q_count.
            progress.empty()
            st.success(f'Completed {num_complete} tests tagged "{tag}"')
            # Give the next run its own tag, matching the one-tag-per-run
            # behaviour; this render still shows the tag that was just used.
            del st.session_state["test_runner_tag"]