import regex as re
import streamlit as st
from pandas import DataFrame
from random import sample
from typing import List, Tuple
from src.architectures import *
from src.common import generate_group_tag
from src.datatypes import *
from src.testing import TestGenerator, batch_test
from src.st_helpers import st_setup


def display_custom_test():
"""
Write to the UI the ability to run some customised tests
"""
st.write("## Run a new custom test")
st.write("### Comment:")
comment = st.text_input("Optional comment for the test", key="custom_test_comment")
st.write("### Architectures to include:")
selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="custom_test_archs")
st.write("### Number of questions to ask:")
q_count = st.slider(label="Number of questions", min_value=1, max_value=TestGenerator.question_count(), step=1, key="custom_q_count")
st.write("### Number of threads to use for testing:")
thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="custom_thread_slider")
st.write("### Tag:")
tag = generate_group_tag()
st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')
total_tests = len(selected_archs) * q_count
st.write("### Run:")
st.write(f"**{total_tests}** total tests will be run")
if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"):
with st.spinner():
questions = TestGenerator.get_random_questions(q_count)
batch_test(questions=questions, architectures=selected_archs,
trace_tags=[tag, "TestRunner"], trace_comment=comment,
num_workers=thread_count)


def display_pricing_fact_test():
    """
    Render the UI controls for running the pricing fact test - asking the
    architectures to state the prices of given products and scoring the
    resulting accuracy
    """

    def get_question_price_pairs():
"""
Returns the complete list of pricing questions along with the correct prices
"""
DataLoader.load_data()
pairs = []
for p in Product.all.values():
price = p.price
product_name = p.name
category_name = p.category.lower_singular_name
            if category_name == "tv":  # restore acronym casing lost by lower_singular_name
                category_name = "TV"
question = f'How much is the {product_name} {category_name}?'
pairs.append((question, price))
return pairs
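
    # Illustrative sketch of the question format built above, assuming a
    # hypothetical product named "Grandview" in the "tv" category priced at 1299.99:
    #   get_question_price_pairs() -> [("How much is the Grandview TV?", 1299.99), ...]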

    def get_price_from_response(response: str) -> float:
"""
Parses a string with regex to get a float value for comparison to what the
architecture said the price was
"""
        prices = re.findall(r'\$[,\d]+\.\d\d', response)
        if len(prices) == 0:
            return -0.1  # sentinel which can never match a real price
        return float(prices[0][1:].replace(',', ''))  # strip '$' and thousands separators
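
    # Worked example of the parser above (response text and price are hypothetical):
    #   get_price_from_response("The Grandview is priced at $1,299.99 today")
    #   -> re.findall yields ['$1,299.99'] -> '$' and ',' stripped -> 1299.99
    #   get_price_from_response("I'm not sure of the price") -> -0.1 (sentinel)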
st.write("## Run a pricing fact test")
st.write("### Comment:")
comment = st.text_input("Optional comment for the test", key="pricing_test_comment")
st.write("### Architectures to include:")
selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="pricing_test_archs")
question_price_pairs = get_question_price_pairs()
st.write("### Number of questions to ask:")
q_count = st.slider(label="Number of questions", min_value=1, max_value=len(question_price_pairs), step=1, key="pricing_q_count")
st.write("### Number of threads to use for testing:")
thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="pricing_thread_slider")
st.write("### Tag:")
tag = generate_group_tag()
st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')
total_tests = len(selected_archs) * q_count
st.write("### Run:")
st.write(f"**{total_tests}** total tests will be run")
if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
question_price_pairs = sample(question_price_pairs, k=q_count)
question_price_dict = {qpp[0]: qpp[1] for qpp in question_price_pairs}
questions = list(question_price_dict.keys())
        # Track per-architecture results as [correct, incorrect] counts
        answer_stats = {arch_name: [0, 0] for arch_name in selected_archs}
with st.spinner():
results: List[Tuple[str, str, str]] = batch_test(questions=questions, architectures=selected_archs,
trace_tags=[tag, "TestRunner"], trace_comment=comment,
num_workers=thread_count)
for arch, query, response in results:
target_price = question_price_dict[query]
answer_price = get_price_from_response(response)
if target_price == answer_price:
answer_stats[arch][0] += 1
else:
answer_stats[arch][1] += 1
            table_data = []
            for arch_name in selected_archs:
                correct, incorrect = answer_stats[arch_name]
                total = correct + incorrect
                percent_correct = (correct / total * 100) if total > 0 else 0.0
                table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])
            df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
            # Substitute a blank index so the table renders without row numbers
            st.table(df.assign(no_index='').set_index('no_index'))


# Make sure the architectures are loaded before the page is rendered
if Architecture.architectures is None:
    Architecture.load_architectures()
if st_setup('LLM Arch'):
st.write("# Test Runner")
with st.expander("Pricing Fact Tests"):
display_pricing_fact_test()
with st.expander("Custom Tests"):
display_custom_test()