"""
This page allows users to run batches of tests against the architectures.  It allows the asking of
general questions from the question set, and also specific pricing questions which are designed
to test and demonstrate the architectures' ability to learn exact facts.
"""

import regex as re
import streamlit as st

from pandas import DataFrame
from random import sample
from src.architectures import *
from src.common import generate_group_tag
from src.datatypes import *
from src.testing import TestGenerator, batch_test
from src.st_helpers import st_setup


def display_custom_test():
    """
    Write to the UI the ability to run some customised tests
    """
    st.write("## Run a new custom test")
    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="custom_test_comment")

    st.write("### Architectures to include:")
    selected_archs = st.multiselect(label="Architectures",
                                    options=[a.name for a in Architecture.architectures],
                                    key="custom_test_archs")

    st.write("### Number of questions to ask:")
    q_count = st.slider(label="Number of questions", min_value=1, max_value=TestGenerator.question_count(), step=1, key="custom_q_count")

    st.write("### Number of threads to use for testing:")
    thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="custom_thread_slider")

    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')

    total_tests = len(selected_archs) * q_count
    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")
    if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"):
        with st.spinner():
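            # Every selected architecture answers every sampled question, spread across the worker threads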
            questions = TestGenerator.get_random_questions(q_count)
            batch_test(questions=questions, architectures=selected_archs,
                       trace_tags=[tag, "TestRunner"], trace_comment=comment,
                       num_workers=thread_count)


def display_pricing_fact_test():
    """
    Write to the UI the ability to run the pricing fact test - asking the
    architectures to state the prices for given products and calculating the resulting accuracy
    """
    def get_question_price_pairs():
        """
        Returns the complete list of pricing questions along with the correct prices
        """
        DataLoader.load_data()
        pairs = []
        for p in Product.all.values():
            price = p.price
            product_name = p.name
            category_name = p.category.lower_singular_name
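            # Cosmetic special case so the generated question reads "TV" rather than "tv"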
            if category_name == "tv":
                category_name = "TV"
            question = f'How much is the {product_name} {category_name}?'
            pairs.append((question, price))
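        # Illustratively, pairs -> [("How much is the Family Size fridge?", 999.99), ...]
        # (hypothetical product; real entries come from the loaded product data)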
        return pairs

    def get_price_from_response(response: str) -> float:
        """
        Parses a string with regex to get a float value for comparison to what the
        architecture said the price was
        """
        prices = re.findall(r'\$[,\d]+\.\d\d', response)
        if len(prices) == 0:
            return -0.1  # sentinel that can never match a real price
        return float(prices[0][1:].replace(',', ''))
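    # e.g. get_price_from_response("That TV costs $1,299.99 today") -> 1299.99;
    # a response with no recognisable dollar amount yields the -0.1 sentinel.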

    st.write("## Run a pricing fact test")
    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="pricing_test_comment")

    st.write("### Architectures to include:")
    selected_archs = st.multiselect(label="Architectures",
                                    options=[a.name for a in Architecture.architectures],
                                    key="pricing_test_archs")

    question_price_pairs = get_question_price_pairs()
    st.write("### Number of questions to ask:")
    q_count = st.slider(label="Number of questions", min_value=1, max_value=len(question_price_pairs), step=1, key="pricing_q_count")

    st.write("### Number of threads to use for testing:")
    thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="pricing_thread_slider")

    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')

    total_tests = len(selected_archs) * q_count
    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")
    if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
        question_price_pairs = sample(question_price_pairs, k=q_count)
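        # Map each sampled question back to its correct price so responses can be marked below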
        question_price_dict = {question: price for question, price in question_price_pairs}
        questions = list(question_price_dict.keys())

        answer_stats = {}
        for arch_name in selected_archs:
            answer_stats[arch_name] = [0, 0]  # [correct, incorrect]

        with st.spinner():
            results: List[Tuple[str, str, str]] = batch_test(questions=questions, architectures=selected_archs,
                                                             trace_tags=[tag, "TestRunner"], trace_comment=comment,
                                                             num_workers=thread_count)
            for arch, query, response in results:
                target_price = question_price_dict[query]
                answer_price = get_price_from_response(response)
                if target_price == answer_price:
                    answer_stats[arch][0] += 1
                else:
                    answer_stats[arch][1] += 1

        table_data = []
        for arch_name in selected_archs:
            correct = answer_stats[arch_name][0]
            incorrect = answer_stats[arch_name][1]
            total = correct + incorrect
            percent_correct = correct / total * 100
            table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])
        df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
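        # Hide the default integer index by substituting a blank dummy column as the index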
        st.table(df.assign(no_index='').set_index('no_index'))


if Architecture.architectures is None:
    Architecture.load_architectures()

if st_setup('LLM Arch'):
    st.write("# Test Runner")
    with st.expander("Pricing Fact Tests"):
        display_pricing_fact_test()
    with st.expander("Custom Tests"):
        display_custom_test()