""" | |
This page allows users to run batches of tests against the architectures. It allows the asking of | |
general questions from the question set, and also specific pricing questions which are designed | |
to test and demonstrate the architectures' ability to learn exact facts. | |
""" | |
import regex as re
import streamlit as st
from pandas import DataFrame
from random import sample
from typing import List, Tuple  # for the batch_test result annotation below

from src.architectures import *
from src.common import generate_group_tag
from src.datatypes import *
from src.testing import TestGenerator, batch_test
from src.st_helpers import st_setup


def display_custom_test():
    """
    Render the UI controls for configuring and running a custom batch of tests
    """
    st.write("## Run a new custom test")
    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="custom_test_comment")
    st.write("### Architectures to include:")
    selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="custom_test_archs")
    st.write("### Number of questions to ask:")
    q_count = st.slider(label="Number of questions", min_value=1, max_value=TestGenerator.question_count(), step=1, key="custom_q_count")
    st.write("### Number of threads to use for testing:")
    thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="custom_thread_slider")
    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')
    total_tests = len(selected_archs) * q_count
    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")
    # One test per (architecture, question) pair; the button stays disabled
    # until at least one architecture is selected
    if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"):
        with st.spinner():
            questions = TestGenerator.get_random_questions(q_count)
            batch_test(questions=questions, architectures=selected_archs,
                       trace_tags=[tag, "TestRunner"], trace_comment=comment,
                       num_workers=thread_count)
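

# Assumed interface of batch_test (from src.testing), inferred from its use in
# this file: it fans each question out to every selected architecture on a pool
# of worker threads and returns (architecture, question, response) triples; the
# pricing test below relies on that shape when scoring answers.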
def display_pricing_fact_test():
    """
    Render the UI controls for running the pricing fact test - asking the
    architectures to state the prices of given products and calculating the
    resulting accuracy
    """
    def get_question_price_pairs():
        """
        Returns the complete list of pricing questions along with the correct prices
        """
        DataLoader.load_data()
        pairs = []
        for p in Product.all.values():
            price = p.price
            product_name = p.name
            category_name = p.category.lower_singular_name
            # "tv" is the one category name that should read as upper case
            if category_name == "tv":
                category_name = "TV"
            question = f'How much is the {product_name} {category_name}?'
            pairs.append((question, price))
        return pairs
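
    # Illustrative example of one generated pair (the product name and price
    # here are made up, not taken from the data set):
    #   ('How much is the UltraView 55 TV?', 1499.99)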

    def get_price_from_response(response: str) -> float:
        """
        Parses a string with regex to get a float value for comparison to what the
        architecture said the price was
        """
        # Match a dollar amount such as $1,299.99; the raw string keeps the
        # backslashes intact for the regex engine
        prices = re.findall(r'\$[,\d]+\.\d\d', response)
        if len(prices) == 0:
            # Sentinel value that can never match a real price, so the answer
            # is scored as incorrect
            return -0.1
        # Strip the "$" and any thousands separators from the first match
        return float(prices[0][1:].replace(',', ''))
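
    # Illustrative behaviour of the parser above (the responses are made up):
    #   get_price_from_response("The TV costs $1,299.99.")   -> 1299.99
    #   get_price_from_response("Sorry, I don't know that.") -> -0.1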

    st.write("## Run a pricing fact test")
    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="pricing_test_comment")
    st.write("### Architectures to include:")
    selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="pricing_test_archs")
    question_price_pairs = get_question_price_pairs()
    st.write("### Number of questions to ask:")
    q_count = st.slider(label="Number of questions", min_value=1, max_value=len(question_price_pairs), step=1, key="pricing_q_count")
    st.write("### Number of threads to use for testing:")
    thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="pricing_thread_slider")
    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')
    total_tests = len(selected_archs) * q_count
    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")
    if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
        # Sample the requested number of questions and index the correct answers
        question_price_pairs = sample(question_price_pairs, k=q_count)
        question_price_dict = {qpp[0]: qpp[1] for qpp in question_price_pairs}
        questions = list(question_price_dict.keys())
        answer_stats = {}
        for arch_name in selected_archs:
            answer_stats[arch_name] = [0, 0]  # [correct, incorrect]
        with st.spinner():
            results: List[Tuple[str, str, str]] = batch_test(questions=questions, architectures=selected_archs,
                                                             trace_tags=[tag, "TestRunner"], trace_comment=comment,
                                                             num_workers=thread_count)
            # Score each response by comparing the parsed price to the known price
            for arch, query, response in results:
                target_price = question_price_dict[query]
                answer_price = get_price_from_response(response)
                if target_price == answer_price:
                    answer_stats[arch][0] += 1
                else:
                    answer_stats[arch][1] += 1
            # Summarise per-architecture accuracy into a table
            table_data = []
            for arch_name in selected_archs:
                correct = answer_stats[arch_name][0]
                incorrect = answer_stats[arch_name][1]
                total = correct + incorrect
                percent_correct = correct / total * 100
                table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])
            df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
            # Use a blank column as the index so st.table does not show row numbers
            st.table(df.assign(no_index='').set_index('no_index'))


# Page entry point: ensure the architectures are loaded, then render the page
if Architecture.architectures is None:
    Architecture.load_architectures()

if st_setup('LLM Arch'):
    st.write("# Test Runner")
    with st.expander("Pricing Fact Tests"):
        display_pricing_fact_test()
    with st.expander("Custom Tests"):
        display_custom_test()