""" This page allows users to run batches of tests against the architectures. It allows the asking of general questions from the question set, and also specific pricing questions which are designed to test and demonstrate the architectures' ability to learn exact facts. """ import regex as re import streamlit as st from pandas import DataFrame from random import sample from src.architectures import * from src.common import generate_group_tag from src.datatypes import * from src.testing import TestGenerator, batch_test from src.st_helpers import st_setup def display_custom_test(): """ Write to the UI the ability to run some customised tests """ st.write("## Run a new custom test") st.write("### Comment:") comment = st.text_input("Optional comment for the test", key="custom_test_comment") st.write("### Architectures to include:") selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="custom_test_archs") st.write("### Number of questions to ask:") q_count = st.slider(label="Number of questions", min_value=1, max_value=TestGenerator.question_count(), step=1, key="custom_q_count") st.write("### Number of threads to use for testing:") thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="custom_thread_slider") st.write("### Tag:") tag = generate_group_tag() st.write(f'Test will be tagged as "{tag}" - record this for easy searching later') total_tests = len(selected_archs) * q_count st.write("### Run:") st.write(f"**{total_tests}** total tests will be run") if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"): with st.spinner(): questions = TestGenerator.get_random_questions(q_count) batch_test(questions=questions, architectures=selected_archs, trace_tags=[tag, "TestRunner"], trace_comment=comment, num_workers=thread_count) def display_pricing_fact_test(): """ Write to the UI the ability to run some of the pricing fact test - asking the architectures to state the prices for given products and calculating the resuting accuracy """ def get_question_price_pairs(): """ Returns the complete list of pricing questions along with the correct prices """ DataLoader.load_data() pairs = [] for p in Product.all.values(): price = p.price product_name = p.name category_name = p.category.lower_singular_name if category_name == "tv": category_name = "TV" question = f'How much is the {product_name} {category_name}?' pairs.append((question, price)) return pairs def get_price_from_response(response: str) -> float: """ Parses a string with regex to get a float value for comparison to what the architecture said the price was """ prices = re.findall('\$[,\d]+\.\d\d', response) if len(prices) == 0: return -0.1 return float(prices[0][1:].replace(',','')) st.write("## Run a pricing fact test") st.write("### Comment:") comment = st.text_input("Optional comment for the test", key="pricing_test_comment") st.write("### Architectures to include:") selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="pricing_test_archs") question_price_pairs = get_question_price_pairs() st.write("### Number of questions to ask:") q_count = st.slider(label="Number of questions", min_value=1, max_value=len(question_price_pairs), step=1, key="pricing_q_count") st.write("### Number of threads to use for testing:") thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="pricing_thread_slider") st.write("### Tag:") tag = generate_group_tag() st.write(f'Test will be tagged as "{tag}" - record this for easy searching later') total_tests = len(selected_archs) * q_count st.write("### Run:") st.write(f"**{total_tests}** total tests will be run") if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"): question_price_pairs = sample(question_price_pairs, k=q_count) question_price_dict = {qpp[0]: qpp[1] for qpp in question_price_pairs} questions = list(question_price_dict.keys()) answer_stats = {} for arch_name in selected_archs: answer_stats[arch_name] = [0, 0] # [correct, incorrect] with st.spinner(): results: List[Tuple[str, str, str]] = batch_test(questions=questions, architectures=selected_archs, trace_tags=[tag, "TestRunner"], trace_comment=comment, num_workers=thread_count) for arch, query, response in results: target_price = question_price_dict[query] answer_price = get_price_from_response(response) if target_price == answer_price: answer_stats[arch][0] += 1 else: answer_stats[arch][1] += 1 table_data = [] for arch_name in selected_archs: correct = answer_stats[arch_name][0] incorrect = answer_stats[arch_name][1] total = correct + incorrect percent_correct = round(correct / total * 100, 1) table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%']) df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct']) st.table(df.assign(no_index='').set_index('no_index')) if Architecture.architectures is None: Architecture.load_architectures() if st_setup('LLM Arch'): st.write("# Test Runner") with st.expander("Pricing Fact Tests"): display_pricing_fact_test() with st.expander("Custom Tests"): display_custom_test()