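# NOTE(review): the three lines originally here ("Spaces:", "Runtime error",
# "Runtime error") appear to be web-page residue rather than part of the module;
# kept as a comment so they are not parsed as code - confirm and remove.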
from __future__ import annotations # For self-referencing annotations | |
import json | |
import os | |
import shutil | |
import sqlite3 | |
import sys | |
from huggingface_hub import Repository | |
from queue import Queue | |
from random import sample | |
from threading import Thread | |
from typing import Dict, List, Optional, Tuple | |
from src.architectures import Architecture, ArchitectureRequest, LogWorker | |
from src.common import data_dir | |
class ArchitectureTestWorker(Thread):
    """
    Worker thread which takes test requests off a shared queue and passes each
    one to the named architecture for execution. Used to multi-thread the testing
    process for speed, as there is a lot of i/o blocking while waiting for the LLM.

    Each queue item is an (architecture_name, request) tuple. An item whose
    architecture name is None signals the worker to stop.
    """

    def __init__(self, work_queue: Queue, worker_name: str, trace_tags: List[str], trace_comment: str):
        """
        :param work_queue: queue of (architecture_name, request) tuples to process
        :param worker_name: label used in the progress messages printed by this worker
        :param trace_tags: tags passed through to every architecture invocation
        :param trace_comment: comment passed through to every architecture invocation
        """
        Thread.__init__(self)
        self.work_queue = work_queue
        self.worker_name = worker_name
        self.trace_tags = trace_tags
        self.trace_comment = trace_comment

    def run(self) -> None:
        """
        Process queue items until the stop signal (architecture name of None) is received.
        Every item taken from the queue is marked done, whether it succeeded or not, so
        that a caller blocked on the queue's join() is always released.
        """
        while True:
            arch, request = self.work_queue.get()
            try:
                if arch is None:  # None passed to signal end of test requests
                    break
                print(f'{self.worker_name} running "{request.request}" through {arch}')
                architecture = Architecture.get_architecture(arch)
                architecture(request, trace_tags=self.trace_tags, trace_comment=self.trace_comment)
            except Exception as e:
                # One failing request must not kill the worker: a dead worker stops draining
                # the queue and, if every worker dies, join() on the queue never returns.
                print(f'{self.worker_name} failed running a request through {arch}: {e!r}')
            finally:
                self.work_queue.task_done()
def batch_test(questions: List[str], architectures: List[str], trace_comment: str = "",
               trace_tags: Optional[List[str]] = None, num_workers: int = 16) -> List[Tuple[str, str, str]]:
    """
    Creates a worker pool and dispatches every question to every architecture,
    returning the answers per architecture and question.
    :param questions: A list of the questions
    :param architectures: A list of the names of the architectures
    :param trace_comment: Comment passed to each worker for every architecture invocation
    :param trace_tags: Tags passed to each worker for every architecture invocation (default: none)
    :param num_workers: The number of worker threads to run (must be at least 1)
    :return: A list of Tuples of (arch_name, question, answer)
    :raises ValueError: if num_workers is less than 1
    """
    if num_workers < 1:
        # With no workers nothing would ever drain the queue and join() would block forever
        raise ValueError("num_workers must be at least 1")
    if trace_tags is None:
        trace_tags = []  # Fresh list per call rather than a shared mutable default

    queue = Queue()
    question_record: Dict[Tuple[str, str], ArchitectureRequest] = {}
    for q in questions:
        for a in architectures:
            request = ArchitectureRequest(q)
            question_record[(a, q)] = request
            queue.put((a, request))

    for i in range(num_workers):
        worker = ArchitectureTestWorker(work_queue=queue, worker_name=f'Worker {i+1}',
                                        trace_tags=trace_tags, trace_comment=trace_comment)
        worker.daemon = True
        worker.start()

    # Each worker stops after consuming one finish flag, so queue one per worker;
    # a single flag would leave all the other workers blocked on get() forever.
    for _ in range(num_workers):
        queue.put((None, None))
    queue.join()

    # Repackage and return just the list of (arch_name, question, answer)
    return [(arch, question, req.response) for (arch, question), req in question_record.items()]
class TestGenerator:
    """
    Wrapper class to hold testing questions and serve up examples.
    All access is through class methods; the questions are cached on the class.
    """
    # Cached questions; None until load_questions has read the json file
    questions: Optional[List[str]] = None

    @classmethod
    def load_questions(cls, reload=False) -> None:
        """
        Load the available questions from the json file.
        Default to not re-loading if already done, but allow for the option to do so
        :param reload: when True, re-read the file even if questions are already cached
        """
        if cls.questions is not None and not reload:
            return
        question_file = os.path.join(data_dir, 'json', 'test_questions.json')
        # Explicit encoding so the result does not depend on the platform default
        with open(question_file, 'r', encoding='utf-8') as f:
            question_json = json.load(f)
        cls.questions = question_json['questions']

    @classmethod
    def question_count(cls) -> int:
        """
        The total number of questions in the question set
        """
        cls.load_questions()
        return len(cls.questions)

    @classmethod
    def get_random_questions(cls, n: int):
        """
        Return n random questions, sampled without repeats from the question set
        :raises ValueError: (from random.sample) if n exceeds the number of questions
        """
        cls.load_questions()
        return sample(cls.questions, k=n)
class ArchitectureRequestRecord:
    """
    Representation of the test data associated with each invocation of an architecture
    """
    # Cache of every record built from the trace log; None until load_all has run
    all: Optional[List[ArchitectureRequestRecord]] = None

    class ArchStep:
        """
        Inner class to just hold the name and timings of one architecture step
        """
        def __init__(self, name: str, start: int, end: int):
            self.name = name
            self.start = start
            self.end = end
            self.elapsed = end - start

    def __init__(self, arch: str, request: str, response: str, response_len: int, start: int, end: int,
                 elapsed: int, tags: List[str], test_group: Optional[str],
                 comment: str, steps: List[ArchitectureRequestRecord.ArchStep]):
        self.arch = arch
        self.request = request
        self.response = response
        self.response_len = response_len
        self.start = start
        self.end = end
        self.elapsed = elapsed
        self.tags = tags
        self.test_group = test_group
        self.comment = comment
        self.steps = steps

    @classmethod
    def from_dict(cls, test: Dict) -> ArchitectureRequestRecord:
        """
        Build a record from one trace dictionary.
        :param test: dict with the keys 'architecture', 'request' (holding 'request_evolution'
        and 'response_evolution' lists), 'trace' (holding a 'steps' list of dicts with 'name',
        'start_ms' and 'end_ms'), 'test_tags' and 'test_comment'
        :return: the populated ArchitectureRequestRecord
        :raises IndexError: if 'request_evolution' or the trace 'steps' list is empty
        """
        request_data = test['request']
        request = request_data['request_evolution'][0]

        # The final entry of the response evolution is the answer; none recorded means empty
        response_evolution = request_data['response_evolution']
        response = response_evolution[-1] if len(response_evolution) > 0 else ""
        response_len = len(response)

        # Overall timing runs from the start of the first step to the end of the last
        trace_steps = test['trace']['steps']
        start = trace_steps[0]['start_ms']
        end = trace_steps[-1]['end_ms']

        tags = test['test_tags']
        test_group = None
        for tag in tags:
            if tag.startswith("TestGroup"):
                test_group = tag  # If several tags match, the last one wins

        steps = [cls.ArchStep(s['name'], s['start_ms'], s['end_ms']) for s in trace_steps]
        return cls(test['architecture'], request, response, response_len, start, end,
                   end - start, tags, test_group, test['test_comment'], steps)

    @classmethod
    def load_all(cls, reload=False) -> None:
        """
        Load all the traces from json trace log into the class level cache
        :param reload: when True, rebuild the cache even if it is already populated
        """
        if cls.all is None or reload:
            cls.all = [cls.from_dict(trace) for trace in Architecture.get_trace_records()]
class TestGroup:
    """
    A class representing a single batch run of tests from the UI. Identified by the tag
    which was assigned from the UI when the test was run, and including summary items
    (start, end, elapsed) for convenience
    """
    # Cache of every loaded TestGroup keyed by its tag; None until load_all has run
    all: Optional[Dict[str, TestGroup]] = None

    def __init__(self, test_group: str):
        self.arch_request_records: List[ArchitectureRequestRecord] = []
        self.test_group = test_group
        self.comment: Optional[str] = None
        self.start: Optional[int] = None
        self.end: Optional[int] = None
        self.elapsed: Optional[int] = None
        self.architectures = set()

    @property
    def num_archs(self) -> int:
        """
        The number of LLM Architectures which were included in this test run from the UI
        """
        return len(self.architectures)

    @property
    def num_tests(self) -> int:
        """
        The total number of Architecture tests (inferences) done in this test run from the UI
        """
        return len(self.arch_request_records)

    @property
    def num_tests_per_arch(self) -> int:
        """
        The calculated number of tests run through each architecture (simple divide as the UI
        forces each architecture to get the same number of requests)
        :raises ZeroDivisionError: if no records have been added to the group yet
        """
        # Should always be a whole number but cast to int just in case
        return int(self.num_tests / self.num_archs)

    def arch_request_records_by_arch(self) -> Dict[str, List[ArchitectureRequestRecord]]:
        """
        Get all the tests ArchitectureRequestRecords grouped by the architecture.
        :return: dict keyed by the architecture name containing a list of ArchitectureRequestRecords
        detailing the tests run through that architecture. Note - the keys are intended to be used for
        display purposes - attempting to use them to load the original architecture will be
        dependent on the availability of that architecture at look up time and changes in architecture
        config could cause that lookup to fail (i.e. the tested architecture is no longer configured).
        """
        grouped: Dict[str, List[ArchitectureRequestRecord]] = {}
        for arr in self.arch_request_records:
            grouped.setdefault(arr.arch, []).append(arr)
        return grouped

    def summary_stats_by_arch(self) -> List[Dict]:
        """
        Get a pack of statistics for use in the UI, detailing this TestGroup.
        :return: a list, sorted by architecture name, of statistics per architecture. Each list item
        is a dict of information (arch_name, elapsed[list of elapsed times in ms],
        response_len[list of the length of the final response in characters], steps[list of the
        individual architecture steps each containing dict of step_name, mean_elapsed(ms)],
        q_and_a[dict of request -> response])
        """
        arch_records = self.arch_request_records_by_arch()
        stats = []
        for arch_name in sorted(arch_records):
            records = arch_records[arch_name]
            num_recs = len(records)
            step_stats = []
            # Step names come from the first record; every record for the architecture
            # is expected to hold the same steps in the same order
            for i, first_step in enumerate(records[0].steps):
                total_elapsed = sum(rec.steps[i].elapsed for rec in records)
                step_stats.append({'step_name': first_step.name,
                                   'mean_elapsed': total_elapsed / num_recs})
            stats.append({'arch_name': arch_name,
                          'elapsed': [rec.elapsed for rec in records],
                          'response_len': [rec.response_len for rec in records],
                          'steps': step_stats,
                          'q_and_a': {rec.request: rec.response for rec in records}})
        return stats

    def add_record(self, arr: ArchitectureRequestRecord) -> None:
        """
        Add an ArchitectureRequestRecord into this test group. Update the
        TestGroup level start, end and elapsed with the new data
        :raises ValueError: if the record carries a different test group tag
        """
        if arr.test_group != self.test_group:
            raise ValueError("Attempted to group a test record into the wrong group")
        self.arch_request_records.append(arr)
        self.architectures.add(arr.arch)
        if self.comment is None:
            self.comment = arr.comment  # The first record added supplies the group comment
        if self.start is None or self.start > arr.start:
            self.start = arr.start
        if self.end is None or self.end < arr.end:
            self.end = arr.end
        self.elapsed = self.end - self.start

    @classmethod
    def load_json_test_groups(cls, reload: bool = False) -> List[TestGroup]:
        """
        Build the test groups from the ArchitectureRequestRecords of the json trace log,
        reloading those records first if requested. Records with no test group are skipped.
        """
        ArchitectureRequestRecord.load_all(reload=reload)
        test_groups: Dict[str, TestGroup] = {}
        for arr in ArchitectureRequestRecord.all:
            if arr.test_group is None:
                continue
            if arr.test_group not in test_groups:
                test_groups[arr.test_group] = cls(arr.test_group)
            test_groups[arr.test_group].add_record(arr)
        return list(test_groups.values())

    @classmethod
    def load_db_test_groups(cls) -> List[TestGroup]:
        """
        Load all the test groups from the DataBase
        """
        db_file = os.path.join(data_dir, 'sqlite', 'test_records.db')
        con = sqlite3.connect(db_file)
        try:
            cur = con.cursor()
            cur.execute("SELECT id, test_group from test_groups")
            tg_id_names = cur.fetchall()
            test_groups: List[TestGroup] = []
            for tg_id, tg_name in tg_id_names:
                tg = cls(tg_name)
                cur.execute("SELECT id, arch_name, request, response, response_len, start, end, comment "
                            "FROM arch_requests WHERE test_group_id=?", (tg_id,))
                # Fetch every row before the cursor is re-used for the step queries below
                arch_requests = cur.fetchall()
                for ar_id, ar_arch_name, ar_req, ar_resp, ar_resp_len, ar_start, ar_end, ar_comment in arch_requests:
                    cur.execute("SELECT name, start, end FROM arch_req_steps WHERE arch_req_id=?", (ar_id,))
                    steps = [ArchitectureRequestRecord.ArchStep(r[0], r[1], r[2]) for r in cur.fetchall()]
                    arch_req_record = ArchitectureRequestRecord(
                        ar_arch_name, ar_req, ar_resp, ar_resp_len, ar_start, ar_end,
                        (ar_end - ar_start), [tg_name], tg_name, f"(DB persisted) {ar_comment}", steps)
                    tg.add_record(arch_req_record)
                test_groups.append(tg)
            return test_groups
        finally:
            con.close()  # Always release the database file handle

    @classmethod
    def force_load_all(cls):
        """
        Convenience wrapper to allow a no parameter call to force the reload of the
        TestGroups without any parameters, for the LogWorker callback
        """
        cls.load_all(True)

    @classmethod
    def load_all(cls, reload: bool = False):
        """
        Load all the available TestGroups, from both the json file and the DB
        into the class variable - for efficiency do not reload unless requested
        """
        # Register the forced refresh once so it is in the LogWorker timeout functions
        if cls.force_load_all not in LogWorker.timeout_functions:
            print("TestGroup adding forced refresh to LogWorker timeout")
            LogWorker.timeout_functions.append(TestGroup.force_load_all)

        if cls.all is None or reload:
            working_test_groups = {}
            for test_group in cls.load_json_test_groups(reload=reload):
                working_test_groups[test_group.test_group] = test_group
            # DB groups are applied second, so they replace a json group with the same tag
            for test_group in cls.load_db_test_groups():
                working_test_groups[test_group.test_group] = test_group
            cls.all = working_test_groups

    @classmethod
    def for_test_group_tag(cls, test_group_tag: str) -> TestGroup:
        """
        Get a single TestGroup based on the test_group_tag which was assigned
        when the test was run
        :raises KeyError: if no TestGroup with that tag has been loaded
        """
        cls.load_all()
        return cls.all[test_group_tag]
def move_test_records_to_db(hf_hub_token: str) -> None:
    """
    This is an offline utility to move the transaction logs from
    a flat file to a database. To keep things simpler transaction logs
    are initially stored into a json file as a Hugging Face dataset, but
    this can get cumbersome, so this utility will move the records
    into an sqlite database. It can be run periodically just to move
    them across.
    :param hf_hub_token: Hugging Face Hub token used to clone the trace repo and to wipe the trace
    """
    def download_latest_json_file(hf_hub_token: str) -> Repository:
        """
        Wipe any local version of the json file and re-download from the HF Hub
        """
        if os.path.exists(Architecture.trace_dir):
            shutil.rmtree(Architecture.trace_dir)
        return Repository(local_dir=Architecture.trace_dir, clone_from=Architecture.save_repo_url, token=hf_hub_token)

    def create_local_db():
        """
        Create the local database file and its three tables
        """
        db_file = os.path.join(data_dir, 'sqlite', 'test_records.db')
        # sqlite can create the file but not a missing parent directory
        os.makedirs(os.path.dirname(db_file), exist_ok=True)
        con = sqlite3.connect(db_file)
        try:
            con.execute("CREATE TABLE test_groups (id INTEGER PRIMARY KEY AUTOINCREMENT, test_group TEXT NOT NULL, start INTEGER NOT NULL, end INTEGER NOT NULL);")
            con.execute("CREATE TABLE arch_requests (id INTEGER PRIMARY KEY AUTOINCREMENT, arch_name TEXT NOT NULL, request TEXT NOT NULL, response TEXT NOT NULL, response_len INTEGER NOT NULL, start INTEGER NOT NULL, end INTEGER NOT NULL, comment TEXT NOT NULL, test_group_id INTEGER NOT NULL, FOREIGN KEY (test_group_id) REFERENCES test_groups (id))")
            con.execute("CREATE TABLE arch_req_steps (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, start INTEGER NOT NULL, end INTEGER NOT NULL, arch_req_id INTEGER NOT NULL, FOREIGN KEY (arch_req_id) REFERENCES arch_requests(id))")
            con.commit()
        finally:
            con.close()

    def get_local_db() -> sqlite3.Connection:
        """
        Get a connection to the local database and create it if it is not already there.
        The caller is responsible for closing the returned connection.
        """
        db_file = os.path.join(data_dir, 'sqlite', 'test_records.db')
        if not os.path.exists(db_file):
            create_local_db()
        return sqlite3.connect(db_file)

    def load_test_group_to_db(test_group: TestGroup, con: sqlite3.Connection) -> None:
        """
        Load a single TestGroup object into the DB, decomposing to the TestGroup,
        ArchitectureRequest within that and ArchitectureRequestSteps within those.
        A TestGroup whose tag is already in the DB is skipped with a warning.
        """
        cur = con.cursor()
        cur.execute('SELECT count(*) from test_groups where test_group = ?', (test_group.test_group,))
        if cur.fetchone()[0] != 0:
            print(f"Warning TestGroup {test_group.test_group} was not added to the DB as it already existed there")
            return

        # All values are bound as parameters so quotes in comments, tags or step names
        # cannot break the statements. str() keeps the stored text identical to what
        # the previous f-string interpolation produced.
        # "with con" commits when the block succeeds and rolls back if any insert fails,
        # so a TestGroup is never left half written.
        with con:
            tg_id = con.execute(
                'INSERT into test_groups (test_group, start, end) VALUES (?, ?, ?)',
                (str(test_group.test_group), test_group.start, test_group.end)).lastrowid
            for arr in test_group.arch_request_records:
                arr_id = con.execute(
                    'INSERT INTO arch_requests (arch_name, request, response, response_len, start, end, comment, test_group_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)',
                    (str(arr.arch), arr.request, arr.response, arr.response_len, arr.start, arr.end,
                     str(arr.comment), tg_id)).lastrowid
                con.executemany(
                    'INSERT INTO arch_req_steps (name, start, end, arch_req_id) VALUES (?, ?, ?, ?)',
                    [(str(s.name), s.start, s.end, arr_id) for s in arr.steps])

    def load_all_test_groups_to_db(con: sqlite3.Connection) -> None:
        """
        Load every available TestGroup to the DB, one at a time
        """
        TestGroup.load_all()
        for tg in TestGroup.all.values():
            load_test_group_to_db(tg, con)

    # Main control flow using utility nested functions above for better structure and readability
    download_latest_json_file(hf_hub_token)
    conn = get_local_db()
    try:
        load_all_test_groups_to_db(conn)
    finally:
        conn.close()
    # Only reached when every TestGroup was stored, so the trace is never wiped after a failed load
    Architecture.wipe_trace(hf_hub_token)
    print("REMINDER: need to commit the local sqlite file to make it available to the server")
if __name__ == "__main__":
    # Expected to only be directly called for the json to db transfer - arg should be the HF token
    if len(sys.argv) < 2:
        # Exit with a usage message instead of an IndexError traceback when the token is missing
        sys.exit(f"Usage: python {sys.argv[0]} <hugging_face_hub_token>")
    move_test_records_to_db(sys.argv[1])