Spaces:
Runtime error
Runtime error
File size: 5,418 Bytes
c319c31 ab87be2 c319c31 b897a48 ab87be2 c319c31 bb7db2c ab87be2 c319c31 f3f6cf6 c319c31 f9e1dd5 fc8884e c319c31 bb7db2c fc8884e c319c31 f3f6cf6 c319c31 f9e1dd5 fc8884e c319c31 b897a48 bb7db2c b897a48 c319c31 bb7db2c fc8884e bb7db2c c319c31 bb7db2c c319c31 bb7db2c c319c31 ab87be2 c319c31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import regex as re
import streamlit as st
from pandas import DataFrame
from random import sample
from src.architectures import *
from src.common import generate_group_tag
from src.datatypes import *
from src.testing import TestGenerator, batch_test
from src.st_helpers import st_setup
# Componentise different test options
def display_custom_test():
    """Render the "custom test" form and launch a batch run on request.

    Collects an optional comment, the architectures to exercise, the number
    of random questions and the worker-thread count, shows the group tag the
    run will carry, and - once the Run button is pressed - passes a random
    question set to ``batch_test``.
    """
    st.write("## Run a new custom test")

    st.write("### Comment:")
    note = st.text_input("Optional comment for the test", key="custom_test_comment")

    st.write("### Architectures to include:")
    arch_names = [arch.name for arch in Architecture.architectures]
    chosen = st.multiselect(label="Architectures", options=arch_names, key="custom_test_archs")

    st.write("### Number of questions to ask:")
    n_questions = st.slider(
        label="Number of questions",
        min_value=1,
        max_value=TestGenerator.question_count(),
        step=1,
        key="custom_q_count",
    )

    st.write("### Number of threads to use for testing:")
    workers = st.slider(
        label="Number of threads",
        min_value=1,
        max_value=64,
        step=1,
        value=16,
        key="custom_thread_slider",
    )

    st.write("### Tag:")
    group_tag = generate_group_tag()
    st.write(f'Test will be tagged as "{group_tag}" - record this for easy searching later')

    # One test per (architecture, question) combination.
    planned = len(chosen) * n_questions
    st.write("### Run:")
    st.write(f"**{planned}** total tests will be run")

    # The button is disabled while no architecture is selected (planned == 0).
    if not st.button("**Run**", disabled=not planned, key="custom_test_button"):
        return

    with st.spinner():
        picked_questions = TestGenerator.get_random_questions(n_questions)
        batch_test(
            questions=picked_questions,
            architectures=chosen,
            trace_tags=[group_tag, "TestRunner"],
            trace_comment=note,
            num_workers=workers,
        )
def display_pricing_fact_test():
    """Render the "pricing fact test" form and, when run, score each architecture.

    Every product in ``Product.all`` yields one question of the form
    "How much is the <name> <category>?". A random sample of those questions
    is sent to the selected architectures through ``batch_test``. A response
    counts as correct when the first dollar amount it contains (``$x.xx``)
    equals the product's price. A per-architecture summary table is displayed
    once the run finishes.
    """

    def get_question_price_pairs():
        """Return a list of ``(question, price)`` tuples, one per product."""
        DataLoader.load_data()
        pairs = []
        for p in Product.all.values():
            category_name = p.category.lower_singular_name
            # "tv" is the one category name that is shown upper-cased.
            if category_name == "tv":
                category_name = "TV"
            question = f'How much is the {p.name} {category_name}?'
            pairs.append((question, p.price))
        return pairs

    def get_price_from_response(response: str) -> float:
        """Return the first ``$x.xx`` amount found in *response*.

        Thousands separators are stripped before conversion. When the
        response contains no such amount, the sentinel ``-0.1`` is returned.
        """
        # Raw string: '\$' and '\d' are invalid escapes in a normal literal
        # and trigger a SyntaxWarning on recent Python versions.
        prices = re.findall(r'\$[,\d]+\.\d\d', response)
        if not prices:
            return -0.1
        # Drop the leading "$" and any commas before converting.
        return float(prices[0][1:].replace(',', ''))

    st.write("## Run a pricing fact test")
    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="pricing_test_comment")
    st.write("### Architectures to include:")
    selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="pricing_test_archs")
    question_price_pairs = get_question_price_pairs()
    st.write("### Number of questions to ask:")
    q_count = st.slider(label="Number of questions", min_value=1, max_value=len(question_price_pairs), step=1, key="pricing_q_count")
    st.write("### Number of threads to use for testing:")
    thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="pricing_thread_slider")
    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')
    total_tests = len(selected_archs) * q_count
    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")

    if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
        question_price_pairs = sample(question_price_pairs, k=q_count)
        # Keyed by question text so each response can be matched to its price.
        question_price_dict = {qpp[0]: qpp[1] for qpp in question_price_pairs}
        questions = list(question_price_dict)

        # arch name -> [correct, incorrect]
        answer_stats = {arch_name: [0, 0] for arch_name in selected_archs}

        with st.spinner():
            results: List[Tuple[str, str, str]] = batch_test(questions=questions, architectures=selected_archs,
                                                             trace_tags=[tag, "TestRunner"], trace_comment=comment,
                                                             num_workers=thread_count)

        for arch, query, response in results:
            target_price = question_price_dict[query]
            answer_price = get_price_from_response(response)
            # NOTE(review): exact equality on prices; if Product.price is a
            # computed float this may need a tolerance - confirm its type.
            if target_price == answer_price:
                answer_stats[arch][0] += 1
            else:
                answer_stats[arch][1] += 1

        table_data = []
        for arch_name in selected_archs:
            correct, incorrect = answer_stats[arch_name]
            total = correct + incorrect
            # An architecture with no returned results would otherwise
            # divide by zero; report it as 0.0% instead.
            percent_correct = round(correct / total * 100, 1) if total else 0.0
            table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])

        df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
        # Replace the numeric index with a blank one so the table shows no row numbers.
        st.table(df.assign(no_index='').set_index('no_index'))
# Load the architecture registry on first use; both sections below read it.
if Architecture.architectures is None:
    Architecture.load_architectures()

# Render the page body only when st_setup returns a truthy value.
if st_setup('LLM Arch'):
    st.write("# Test Runner")
    # (expander title, renderer) in display order.
    page_sections = (
        ("Pricing Fact Tests", display_pricing_fact_test),
        ("Custom Tests", display_custom_test),
    )
    for section_title, render_section in page_sections:
        with st.expander(section_title):
            render_section()
|