File size: 5,418 Bytes
c319c31
ab87be2
 
c319c31
b897a48
ab87be2
 
c319c31
bb7db2c
ab87be2
 
 
c319c31
 
 
 
 
 
 
 
 
 
f3f6cf6
c319c31
f9e1dd5
fc8884e
 
c319c31
 
 
 
 
 
 
 
bb7db2c
 
 
fc8884e
 
c319c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3f6cf6
c319c31
f9e1dd5
fc8884e
 
c319c31
 
 
 
 
 
 
 
b897a48
bb7db2c
 
b897a48
c319c31
 
bb7db2c
 
 
 
fc8884e
 
bb7db2c
 
 
 
 
c319c31
bb7db2c
 
c319c31
 
 
 
 
 
 
 
bb7db2c
c319c31
 
ab87be2
 
 
 
c319c31
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import regex as re
import streamlit as st

from pandas import DataFrame
from random import sample
from src.architectures import *
from src.common import generate_group_tag
from src.datatypes import *
from src.testing import TestGenerator, batch_test
from src.st_helpers import st_setup


# Componentise different test options
def display_custom_test():
    """Render the "custom test" form and run a batch test when the user asks for it.

    The form collects an optional comment, the architectures to exercise, the
    number of random questions and the number of worker threads. A group tag is
    generated and shown so the run can be found later. Pressing *Run* draws that
    many random questions from ``TestGenerator`` and passes them to
    ``batch_test``, traced with the group tag plus ``"TestRunner"``.
    """
    st.write("## Run a new custom test")

    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="custom_test_comment")

    st.write("### Architectures to include:")
    arch_names = [arch.name for arch in Architecture.architectures]
    chosen_archs = st.multiselect(label="Architectures", options=arch_names, key="custom_test_archs")

    st.write("### Number of questions to ask:")
    available_questions = TestGenerator.question_count()
    num_questions = st.slider(label="Number of questions", min_value=1, max_value=available_questions,
                              step=1, key="custom_q_count")

    st.write("### Number of threads to use for testing:")
    num_workers = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16,
                            key="custom_thread_slider")

    st.write("### Tag:")
    group_tag = generate_group_tag()
    st.write(f'Test will be tagged as "{group_tag}" - record this for easy searching later')

    # One test per (architecture, question) combination.
    planned_runs = len(chosen_archs) * num_questions
    st.write("### Run:")
    st.write(f"**{planned_runs}** total tests will be run")

    nothing_to_run = planned_runs == 0
    if not st.button("**Run**", disabled=nothing_to_run, key="custom_test_button"):
        return

    with st.spinner():
        picked_questions = TestGenerator.get_random_questions(num_questions)
        batch_test(
            questions=picked_questions,
            architectures=chosen_archs,
            trace_tags=[group_tag, "TestRunner"],
            trace_comment=comment,
            num_workers=num_workers,
        )


# Dollar amounts that include cents, e.g. "$12.99" or "$1,299.00". This is a raw
# string because "\$" and "\d" are invalid escape sequences in a normal string
# literal, which newer Pythons report as a SyntaxWarning.
PRICE_PATTERN = re.compile(r'\$[,\d]+\.\d\d')

# Sentinel returned when a response contains no recognisable price. It is
# negative, so it should never match a real (presumably non-negative) price.
NO_PRICE_FOUND = -0.1


def get_question_price_pairs() -> list:
    """Build one "How much is the <product> <category>?" question per product.

    Calls ``DataLoader.load_data()`` and then walks ``Product.all``.

    Returns:
        A list of ``(question, price)`` tuples, where ``price`` is the product's
        ``price`` attribute as stored.
    """
    DataLoader.load_data()
    pairs = []
    for product in Product.all.values():
        price = product.price
        category_name = product.category.lower_singular_name
        # The lower-case singular category name would read "tv"; keep the acronym upper-case.
        if category_name == "tv":
            category_name = "TV"
        pairs.append((f'How much is the {product.name} {category_name}?', price))
    return pairs


def get_price_from_response(response: str) -> float:
    """Return the first dollar amount in *response* as a float.

    Thousands separators are stripped, so "$1,299.99" yields 1299.99. Only
    amounts written with exactly two decimal places are recognised.

    Returns:
        The parsed amount, or ``NO_PRICE_FOUND`` (-0.1) if no amount is present.
    """
    prices = PRICE_PATTERN.findall(response)
    if not prices:
        return NO_PRICE_FOUND
    # Drop the leading "$" and any thousands separators before converting.
    return float(prices[0][1:].replace(',', ''))


def prices_match(target_price, answer_price) -> bool:
    """Return True when two prices are equal to the cent.

    Exact float equality is fragile for currency (12.990000000000002 != 12.99).
    Rounding both values to integer cents avoids that. ``target_price`` only has
    to be convertible with ``float()``.
    """
    return round(float(target_price) * 100) == round(float(answer_price) * 100)


def display_pricing_fact_test():
    """Render the pricing-fact test form and score architectures on product prices.

    Each question asks for the price of one product. An answer counts as
    correct when the first dollar amount in the response equals the product's
    price to the cent. After a run, a per-architecture table of correct and
    incorrect counts is displayed.
    """
    st.write("## Run a pricing fact test")
    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="pricing_test_comment")

    st.write("### Architectures to include:")
    selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="pricing_test_archs")

    question_price_pairs = get_question_price_pairs()
    if not question_price_pairs:
        # A slider with max_value=0 below min_value=1 is invalid, so bail out early.
        st.warning("No products were loaded, so there are no pricing questions to ask.")
        return

    st.write("### Number of questions to ask:")
    q_count = st.slider(label="Number of questions", min_value=1, max_value=len(question_price_pairs), step=1, key="pricing_q_count")

    st.write("### Number of threads to use for testing:")
    thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="pricing_thread_slider")

    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')

    total_tests = len(selected_archs) * q_count
    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")
    if not st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
        return

    sampled_pairs = sample(question_price_pairs, k=q_count)
    # Identical question texts collapse into one dict entry, so fewer than
    # q_count distinct questions may actually be asked.
    question_price_dict = dict(sampled_pairs)
    questions = list(question_price_dict)

    # Per architecture name: [correct, incorrect].
    answer_stats = {arch_name: [0, 0] for arch_name in selected_archs}

    with st.spinner():
        results: List[Tuple[str, str, str]] = batch_test(questions=questions, architectures=selected_archs,
                                                         trace_tags=[tag, "TestRunner"], trace_comment=comment,
                                                         num_workers=thread_count)
        for arch, query, response in results:
            target_price = question_price_dict[query]
            answer_price = get_price_from_response(response)
            if prices_match(target_price, answer_price):
                answer_stats[arch][0] += 1
            else:
                answer_stats[arch][1] += 1

    table_data = []
    for arch_name in selected_archs:
        correct, incorrect = answer_stats[arch_name]
        total = correct + incorrect
        # Guard against an architecture with no results from batch_test (avoids ZeroDivisionError).
        percent_correct = f'{correct / total * 100:.1f}%' if total else 'n/a'
        table_data.append([arch_name, correct, incorrect, total, percent_correct])
    df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
    # Replace the default integer index with a blank one so no row numbers are shown.
    st.table(df.assign(no_index='').set_index('no_index'))


# Populate the architecture registry only if it has not been loaded yet.
if Architecture.architectures is None:
    Architecture.load_architectures()

# The page body is rendered only when st_setup('LLM Arch') returns a truthy value.
if st_setup('LLM Arch'):
    st.write("# Test Runner")
    # Each section: expander title and the function that renders its contents, in display order.
    for section_title, render_section in (
        ("Pricing Fact Tests", display_pricing_fact_test),
        ("Custom Tests", display_custom_test),
    ):
        with st.expander(section_title):
            render_section()