Spaces:
Runtime error
Runtime error
File size: 6,183 Bytes
4f07f72 c319c31 ab87be2 c319c31 b897a48 ab87be2 c319c31 bb7db2c ab87be2 c319c31 5f6f1d0 c319c31 f3f6cf6 c319c31 f9e1dd5 fc8884e c319c31 bb7db2c fc8884e c319c31 5f6f1d0 c319c31 5f6f1d0 c319c31 5f6f1d0 c319c31 f3f6cf6 c319c31 f9e1dd5 fc8884e c319c31 b897a48 bb7db2c b897a48 c319c31 bb7db2c fc8884e bb7db2c c319c31 bb7db2c c319c31 bb7db2c c319c31 ab87be2 c319c31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
"""
This page allows users to run batches of tests against the architectures. It allows the asking of
general questions from the question set, and also specific pricing questions which are designed
to test and demonstrate the architectures' ability to learn exact facts.
"""
import regex as re
import streamlit as st
from pandas import DataFrame
from random import sample
from src.architectures import *
from src.common import generate_group_tag
from src.datatypes import *
from src.testing import TestGenerator, batch_test
from src.st_helpers import st_setup
def display_custom_test():
    """
    Render the "custom test" panel.

    The user picks which architectures to include, how many random questions to
    ask and how many worker threads to use. Pressing Run fetches that many random
    questions from TestGenerator and passes them to batch_test, tagged with a group
    tag generated on every call of this function. The batch results are not
    displayed by this panel.
    """
    st.write("## Run a new custom test")

    st.write("### Comment:")
    test_comment = st.text_input("Optional comment for the test", key="custom_test_comment")

    st.write("### Architectures to include:")
    arch_options = [arch.name for arch in Architecture.architectures]
    chosen_archs = st.multiselect(label="Architectures", options=arch_options, key="custom_test_archs")

    st.write("### Number of questions to ask:")
    max_questions = TestGenerator.question_count()
    question_total = st.slider(label="Number of questions", min_value=1, max_value=max_questions,
                               step=1, key="custom_q_count")

    st.write("### Number of threads to use for testing:")
    workers = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16,
                        key="custom_thread_slider")

    st.write("### Tag:")
    group_tag = generate_group_tag()
    st.write(f'Test will be tagged as "{group_tag}" - record this for easy searching later')

    # One test per (architecture, question) combination
    planned_tests = question_total * len(chosen_archs)
    st.write("### Run:")
    st.write(f"**{planned_tests}** total tests will be run")

    nothing_to_run = planned_tests == 0
    if not st.button("**Run**", disabled=nothing_to_run, key="custom_test_button"):
        return

    with st.spinner():
        picked_questions = TestGenerator.get_random_questions(question_total)
        # batch_test's return value is not used by this panel
        batch_test(questions=picked_questions, architectures=chosen_archs,
                   trace_tags=[group_tag, "TestRunner"], trace_comment=test_comment,
                   num_workers=workers)
def display_pricing_fact_test():
    """
    Write to the UI the ability to run some of the pricing fact test - asking the
    architectures to state the prices for given products and calculating the resulting accuracy.

    Each selected architecture is asked "How much is the <product> <category>?" for a
    random sample of products. A response is scored correct when the first dollar
    amount it contains matches the product's price to within half a cent. The results
    are shown as a per-architecture table of correct/incorrect counts.
    """
    import math

    def get_question_price_pairs():
        """
        Returns the complete list of (question, correct price) pairs, one per product
        in Product.all, after calling DataLoader.load_data()
        """
        DataLoader.load_data()
        pairs = []
        for product in Product.all.values():
            category_name = product.category.lower_singular_name
            # "tv" is an acronym, so capitalise it so the question reads naturally
            if category_name == "tv":
                category_name = "TV"
            question = f'How much is the {product.name} {category_name}?'
            pairs.append((question, product.price))
        return pairs

    def get_price_from_response(response: str) -> float:
        """
        Returns the first dollar amount written like "$1,234.56" in the response as a
        float (commas removed), or -0.1 when no such amount is present. Amounts without
        cents (e.g. "$20") are not recognised.
        """
        # Raw string: '\$', '\d' and '\.' are invalid escape sequences in a plain literal
        prices = re.findall(r'\$[,\d]+\.\d\d', response)
        if not prices:
            # Sentinel; presumably never equal to a real product price
            return -0.1
        return float(prices[0][1:].replace(',', ''))

    def prices_match(target_price, answer_price) -> bool:
        """
        True when the two prices agree to within half a cent. A tolerance is used
        instead of == because exact equality is fragile for currency values.
        """
        return math.isclose(target_price, answer_price, rel_tol=0.0, abs_tol=0.005)

    st.write("## Run a pricing fact test")
    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="pricing_test_comment")
    st.write("### Architectures to include:")
    selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="pricing_test_archs")
    question_price_pairs = get_question_price_pairs()
    if not question_price_pairs:
        # The question-count slider below needs at least one question to offer
        st.write("No products are loaded, so there are no pricing questions to ask.")
        return
    st.write("### Number of questions to ask:")
    q_count = st.slider(label="Number of questions", min_value=1, max_value=len(question_price_pairs), step=1, key="pricing_q_count")
    st.write("### Number of threads to use for testing:")
    thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="pricing_thread_slider")
    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')
    total_tests = len(selected_archs) * q_count
    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")
    if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
        chosen_pairs = sample(question_price_pairs, k=q_count)
        # NOTE(review): products sharing a name and category produce identical questions,
        # which collapse to a single key here - TODO confirm product names are unique
        question_price_dict = {question: price for question, price in chosen_pairs}
        questions = list(question_price_dict)
        answer_stats = {arch_name: [0, 0] for arch_name in selected_archs}  # [correct, incorrect]
        with st.spinner():
            results: List[Tuple[str, str, str]] = batch_test(questions=questions, architectures=selected_archs,
                                                             trace_tags=[tag, "TestRunner"], trace_comment=comment,
                                                             num_workers=thread_count)
        for arch, query, response in results:
            target_price = question_price_dict[query]
            answer_price = get_price_from_response(response)
            if prices_match(target_price, answer_price):
                answer_stats[arch][0] += 1
            else:
                answer_stats[arch][1] += 1
        table_data = []
        for arch_name in selected_archs:
            correct, incorrect = answer_stats[arch_name]
            total = correct + incorrect
            # An architecture with no returned results would otherwise divide by zero
            percent_text = f'{correct / total * 100:.1f}%' if total else 'n/a'
            table_data.append([arch_name, correct, incorrect, total, percent_text])
        df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
        # Replace the numeric index with a blank one so the table shows no row numbers
        st.table(df.assign(no_index='').set_index('no_index'))
# Populate the architecture registry only if it has not been loaded yet; the
# display functions below read Architecture.architectures.
if Architecture.architectures is None:
    Architecture.load_architectures()

if st_setup('LLM Arch'):
    st.write("# Test Runner")
    # Each collapsible section is rendered by its own display function, in this order
    for section_title, render_section in (("Pricing Fact Tests", display_pricing_fact_test),
                                          ("Custom Tests", display_custom_test)):
        with st.expander(section_title):
            render_section()
|