alfraser committed on
Commit
c0a1e47
·
1 Parent(s): bb7db2c

Implemented a single-threaded worker for writing the logs to the JSON file, giving controlled access to that file-system resource now that we are multi-threading the tests.

Browse files
Files changed (2) hide show
  1. src/architectures.py +88 -40
  2. src/testing.py +9 -4
src/architectures.py CHANGED
@@ -14,6 +14,8 @@ import traceback
14
  from abc import ABC, abstractmethod
15
  from enum import Enum
16
  from huggingface_hub import Repository
 
 
17
  from time import time
18
  from typing import List, Optional, Dict
19
  from better_profanity import profanity
@@ -185,6 +187,91 @@ class ArchitectureComponent(ABC):
185
  return ""
186
 
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  class Architecture:
189
  """
190
  An architecture is built as a callable pipeline of steps. An
@@ -275,45 +362,6 @@ class Architecture:
275
  return a
276
  raise ValueError(f"Could not find an architecture named {name}")
277
 
278
- @classmethod
279
- def append_and_save_data_as_json(cls, data: Dict):
280
- if cls.save_repo is None and not cls.save_repo_load_error:
281
- try:
282
- hf_write_token = hf_api_token(write=True)
283
- cls.save_repo = Repository(local_dir=cls.trace_dir, clone_from=cls.save_repo_url, token=hf_write_token)
284
- except Exception as err:
285
- cls.save_repo_load_error = True
286
- print(f"Error connecting to the save repo {err} - persistence now disabled")
287
-
288
- if cls.save_repo is not None:
289
- with open(cls.trace_file, 'r') as f:
290
- test_json = json.load(f)
291
- test_json['tests'].append(data)
292
- with open(cls.trace_file, 'w') as f:
293
- json.dump(test_json, f, indent=2)
294
- cls.save_repo.push_to_hub()
295
-
296
- def attempt_request_and_trace_save(self, request: ArchitectureRequest, trace: ArchitectureTrace,
297
- trace_tags: List[str] = None, trace_comment: str = None) -> None:
298
- """
299
- Attempt to save a request and trace pair to a json store on huggingface datasets
300
- Catch any errors and simply print as non-fatal to functional flow
301
- """
302
- try:
303
- if trace_tags is None:
304
- trace_tags = []
305
- if trace_comment is None:
306
- trace_comment = ""
307
- save_dict = {
308
- 'architecture': self.name,
309
- 'request': request.as_dict(),
310
- 'trace': trace.as_dict(),
311
- 'test_tags': trace_tags,
312
- 'test_comment': trace_comment
313
- }
314
- self.append_and_save_data_as_json(save_dict)
315
- except Exception as err:
316
- print(f"Request / trace save failed {err}")
317
 
318
  def __init__(self,
319
  name: str,
@@ -355,7 +403,7 @@ class Architecture:
355
  trace.steps[-1].exception = err
356
  traceback.print_exc()
357
  break
358
- self.attempt_request_and_trace_save(request, trace, trace_tags, trace_comment)
359
  return trace
360
 
361
 
 
14
  from abc import ABC, abstractmethod
15
  from enum import Enum
16
  from huggingface_hub import Repository
17
+ from queue import Queue
18
+ from threading import Thread, Timer
19
  from time import time
20
  from typing import List, Optional, Dict
21
  from better_profanity import profanity
 
187
  return ""
188
 
189
 
190
class LogWorker(Thread):
    """
    Single background thread which serialises every write of a request / trace
    record to the json trace file, so that concurrent callers never touch the
    file (or the repository it lives in) at the same time.

    Callers hand records over with LogWorker.write(). The worker appends each
    record to the trace file and pushes the repository to the hub either once
    commit_after records are outstanding, or once commit_time seconds have
    passed without a new record arriving. All state is held at class level.
    """
    instance = None  # The worker instance, assigned by the code which starts the worker
    architectures = None  # Not referenced anywhere within this class
    save_repo = None  # Repository clone, created lazily when the first record is logged
    save_repo_load_error = False  # Set once cloning fails so that it is not attempted again
    save_repo_url = "https://huggingface.co/datasets/alfraser/llm-arch-trace"
    trace_dir = "trace"
    trace_file_name = "trace.json"
    trace_file = os.path.join(trace_dir, trace_file_name)
    queue = Queue()
    commit_time = 10  # Number of seconds after which to commit with no activity
    commit_after = 10  # Number of records after which to commit irrespective of time
    commit_count = 0  # Current uncommitted records
    commit_timer = None  # The actual commit timer - we will schedule the commit on this

    def run(self):
        """
        Thread body: block on the queue and process one item at a time, forever.
        """
        while True:
            self._process(LogWorker.queue.get())

    def _process(self, item) -> None:
        """
        Handle one queue item, which is either a record to log or the commit signal.

        A record is a tuple (request, trace, trace_tags, trace_comment), optionally
        followed by the architecture name. The commit signal is a tuple whose first
        element is None. Failures are printed and never propagate, so that one bad
        record or a failed push cannot kill the worker thread.
        """
        request, trace, trace_tags, trace_comment, *extra = item

        # Whatever arrived, the pending inactivity timer is now out of date
        LogWorker._cancel_timer()

        if request is None:
            # Commit signal - only worth pushing if something is outstanding
            if LogWorker.commit_count > 0:
                try:
                    LogWorker.commit_repo()
                except Exception as err:
                    print(f"Timed commit failed {err}")
        else:
            # NOTE(review): self.name is the name of this Thread, not of an
            # architecture. It is only kept as the fallback so that callers which
            # do not pass architecture_name to write() see no change in output.
            architecture_name = extra[0] if extra and extra[0] is not None else self.name
            try:
                save_dict = {
                    'architecture': architecture_name,
                    'request': request.as_dict(),
                    'trace': trace.as_dict(),
                    'test_tags': trace_tags,
                    'test_comment': trace_comment
                }
                LogWorker.append_and_save_data_as_json(save_dict)
                LogWorker.commit_count += 1
                if LogWorker.commit_count >= LogWorker.commit_after:
                    LogWorker.commit_repo()
            except Exception as err:
                print(f"Request / trace save failed {err}")

        # Only arm the inactivity timer while there is something left to commit,
        # otherwise an idle worker would signal itself every commit_time seconds
        if LogWorker.commit_count > 0:
            LogWorker.commit_timer = Timer(LogWorker.commit_time, LogWorker.signal_commit)
            LogWorker.commit_timer.start()

    @classmethod
    def _cancel_timer(cls) -> None:
        """
        Cancel the pending inactivity timer, if there is one, and forget it.
        """
        if LogWorker.commit_timer is not None:
            LogWorker.commit_timer.cancel()  # Harmless if the timer has already fired
            LogWorker.commit_timer = None

    @classmethod
    def append_and_save_data_as_json(cls, data: Dict):
        """
        Append one record to the 'tests' list in the trace file, cloning the save
        repository first if that has not been attempted yet. If the clone fails
        persistence is disabled and the record is dropped.
        """
        print(f"LogWorker logging open record {LogWorker.commit_count + 1}")
        if cls.save_repo is None and not cls.save_repo_load_error:
            try:
                hf_write_token = hf_api_token(write=True)
                cls.save_repo = Repository(local_dir=cls.trace_dir, clone_from=cls.save_repo_url, token=hf_write_token)
            except Exception as err:
                cls.save_repo_load_error = True
                print(f"Error connecting to the save repo {err} - persistence now disabled")

        if cls.save_repo is not None:
            # Read - modify - write of the whole file; safe only because this
            # worker is the single writer
            with open(cls.trace_file, 'r') as f:
                test_json = json.load(f)
            test_json['tests'].append(data)
            with open(cls.trace_file, 'w') as f:
                json.dump(test_json, f, indent=2)

    @classmethod
    def commit_repo(cls):
        """
        Push the save repository to the hub and reset the open record count.
        When there is no save repository (the clone failed) nothing was written,
        so the count is simply reset rather than raising on the missing repo.
        """
        if cls.save_repo is None:
            print(f"LogWorker has no save repo - discarding {LogWorker.commit_count} open records")
            LogWorker.commit_count = 0
            return
        print(f"LogWorker committing {LogWorker.commit_count} open records")
        cls.save_repo.push_to_hub()
        LogWorker.commit_count = 0

    @classmethod
    def signal_commit(cls):
        """
        Timer callback: queue the commit signal so that the push itself happens
        on the worker thread rather than on the timer thread.
        """
        print("LogWorker signalling commit based on time elapsed")
        cls.queue.put((None, None, None, None))

    @classmethod
    def write(cls, request: "ArchitectureRequest", trace: "ArchitectureTrace",
              trace_tags: List[str] = None, trace_comment: str = None,
              architecture_name: Optional[str] = None):
        """
        Queue a request / trace pair for the worker to log. Returns immediately.

        :param request: the request which was run, must offer as_dict()
        :param trace: the trace produced for the request, must offer as_dict()
        :param trace_tags: optional tags stored with the record, defaults to []
        :param trace_comment: optional comment stored with the record, defaults to ""
        :param architecture_name: optional name stored as the record's 'architecture';
               when omitted the worker thread's own name is stored instead
        """
        trace_tags = [] if trace_tags is None else trace_tags
        trace_comment = "" if trace_comment is None else trace_comment
        cls.queue.put((request, trace, trace_tags, trace_comment, architecture_name))
267
+
268
+
269
# Start the single shared log writing thread the first time this module is
# imported; the guard stops a second worker being created if one is already set
if LogWorker.instance is None:
    LogWorker.instance = LogWorker()
    LogWorker.instance.start()
273
+
274
+
275
  class Architecture:
276
  """
277
  An architecture is built as a callable pipeline of steps. An
 
362
  return a
363
  raise ValueError(f"Could not find an architecture named {name}")
364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
  def __init__(self,
367
  name: str,
 
403
  trace.steps[-1].exception = err
404
  traceback.print_exc()
405
  break
406
+ LogWorker.write(request, trace, trace_tags, trace_comment)
407
  return trace
408
 
409
 
src/testing.py CHANGED
@@ -9,14 +9,19 @@ import sys
9
  from huggingface_hub import Repository
10
  from queue import Queue
11
  from random import choices
12
- from threading import Thread
13
  from typing import Dict, List, Optional, Tuple
14
 
15
- from src.architectures import Architecture, ArchitectureRequest
16
- from src.common import data_dir
17
 
18
 
19
  class ArchitectureTestWorker(Thread):
 
 
 
 
 
20
  def __init__(self, work_queue: Queue, worker_name: str, trace_tags: List[str], trace_comment: str):
21
  Thread.__init__(self)
22
  self.work_queue = work_queue
@@ -29,7 +34,7 @@ class ArchitectureTestWorker(Thread):
29
  while running:
30
  arch, request = self.work_queue.get()
31
  try:
32
- if arch is None:
33
  running = False
34
  else:
35
  print(f'{self.worker_name} running "{request.request}" through {arch}')
 
9
  from huggingface_hub import Repository
10
  from queue import Queue
11
  from random import choices
12
+ from threading import Thread, Timer
13
  from typing import Dict, List, Optional, Tuple
14
 
15
+ from src.architectures import Architecture, ArchitectureRequest, ArchitectureTrace
16
+ from src.common import data_dir, hf_api_token
17
 
18
 
19
  class ArchitectureTestWorker(Thread):
20
+ """
21
+ This class is a worker which takes a test request off the queue and passes
22
+ it to an architecture for execution. Used to multi-thread the testing process
23
+ for speed as there is a tonne of i/o blocking waiting for the LLM
24
+ """
25
  def __init__(self, work_queue: Queue, worker_name: str, trace_tags: List[str], trace_comment: str):
26
  Thread.__init__(self)
27
  self.work_queue = work_queue
 
34
  while running:
35
  arch, request = self.work_queue.get()
36
  try:
37
+ if arch is None: # None passed to signal end of test requests
38
  running = False
39
  else:
40
  print(f'{self.worker_name} running "{request.request}" through {arch}')