alfraser committed
Commit 8f8b146 · Parent: 63018b5

Reviewed comments and type hints

Files changed (1)
  1. src/architectures.py +63 -18
src/architectures.py CHANGED
@@ -1,6 +1,8 @@
  """
  This file contains all the code which defines architectures and
- architecture components.
+ architecture components. An architecture is modelled as a pipeline of ArchitectureComponents
+ through which an ArchitectureRequest flows. Architectures are configured in the file
+ config/architectures.json
  """

  import chromadb
@@ -82,7 +84,7 @@ class ArchitectureTraceOutcome(Enum):

  class ArchitectureTraceStep:
  """
- Class to hold the details of a single trace step
+ Class to hold the trace details of a single step in an Architecture pipeline
  """
  def __init__(self, name: str):
  self.name = name
@@ -165,6 +167,11 @@ class ArchitectureTrace:


  class ArchitectureComponent(ABC):
+ """
+ This is the abstract base class for all classes which want to be concrete components available
+ to be configured into an Architecture pipeline. Specifies the elements which need to be implemented
+ to be a compliant architecture component.
+ """
  description = "Components should override a description"

  @abstractmethod
@@ -188,6 +195,17 @@ class ArchitectureComponent(ABC):


  class LogWorker(Thread):
+ """
+ The LogWorker implements a daemon thread which runs in the background to write the results
+ of user queries through the system to a log file for analysis/reporting and offline saving.
+ The LogWorker provides two functions to the system. 1) it moves this I/O operation out of the
+ main architecture execution, which allows for a clearer understanding of the true performance of the
+ architectures themselves. 2) it is designed to be run as a single thread to provide controlled
+ shared access to a resource (the log file) with an in-memory queue for thread safety, which then
+ allows us to multi-thread the architecture invocation itself. In addition, the LogWorker provides
+ some basic batching capabilities for performance (e.g. batches up N requests before committing the IO
+ operation to the file, or commits open activity after a set period of inactivity).
+ """
  instance = None
  architectures = None
  save_repo = None
@@ -207,6 +225,7 @@ class LogWorker(Thread):
  while True:
  arch_name, request, trace, trace_tags, trace_comment = LogWorker.queue.get()
  if request is None:
+ # There was a period of inactivity, so run the timeout functions
  for func in LogWorker.timeout_functions:
  print(f"LogWorker commit running {func.__name__}")
  try:
@@ -215,6 +234,7 @@ class LogWorker(Thread):
  print(f"Timeout func {func.__name__} had error {e}")
  else:
  if LogWorker.commit_timer is not None and LogWorker.commit_timer.is_alive():
+ # Cancel the inactivity timer
  LogWorker.commit_timer.cancel()
  LogWorker.commit_timer = None
  try:
@@ -232,11 +252,16 @@ class LogWorker(Thread):
  except Exception as err:
  print(f"Request / trace save failed {err}")

+ # Restart the inactivity timer
  LogWorker.commit_timer = Timer(LogWorker.commit_time, LogWorker.signal_commit)
  LogWorker.commit_timer.start()

  @classmethod
- def append_and_save_data_as_json(cls, data: Dict):
+ def append_and_save_data_as_json(cls, data: Dict) -> None:
+ """
+ If the working log file has not been downloaded, then get a local copy.
+ Add the new record to the local file.
+ """
  print(f"LogWorker logging open record {LogWorker.commit_count + 1}")
  if cls.save_repo is None and not cls.save_repo_load_error:
  try:
@@ -255,6 +280,9 @@ class LogWorker(Thread):

  @classmethod
  def commit_repo(cls):
+ """
+ If there are any changes in the local file which are not committed to the repo then commit them.
+ """
  if cls.commit_count > 0:
  print(f"LogWorker committing {LogWorker.commit_count} open records")
  cls.save_repo.push_to_hub()
@@ -270,7 +298,11 @@ class LogWorker(Thread):

  @classmethod
  def write(cls, arch_name: str, request: ArchitectureRequest, trace: ArchitectureTrace,
- trace_tags: List[str] = None, trace_comment: str = None):
+ trace_tags: List[str] = None, trace_comment: str = None) -> None:
+ """
+ Class method callable from across the system to put a logging request onto the queue so that
+ the LogWorker will pick it up in turn and write it to the log
+ """
  trace_tags = [] if trace_tags is None else trace_tags
  trace_comment = "" if trace_comment is None else trace_comment
  cls.queue.put((arch_name, request, trace, trace_tags, trace_comment))
@@ -302,7 +334,10 @@ class Architecture:
  trace_file = os.path.join(trace_dir, trace_file_name)

  @classmethod
- def wipe_trace(cls, hf_write_token:str = None):
+ def wipe_trace(cls, hf_write_token:str = None) -> None:
+ """
+ Wipes the json trace file - note this will not delete any records which have been saved offline to the database
+ """
  if os.path.exists(cls.trace_dir):
  shutil.rmtree(cls.trace_dir)
  try:
@@ -319,6 +354,9 @@ class Architecture:

  @classmethod
  def get_trace_records(cls) -> List[Dict]:
+ """
+ Loads and returns all the trace records which are held in the trace file
+ """
  if not os.path.isfile(cls.trace_file):
  hf_write_token = hf_api_token(write=True)
  try:
@@ -395,8 +433,8 @@ class Architecture:
  in sequence, allowing them to amend the request or early exit the processing. Also captures
  exceptions and generates the trace, plus saves the request/response and the trace to a store
  for analysis.
- :param request:
- :return:
+ :param request: The architecture request to pass down the pipeline
+ :return: The trace record for this invocation of the architecture
  """
  print(f'{self.name} processing query "{request.request}"')
  trace = ArchitectureTrace()
@@ -420,6 +458,10 @@ class Architecture:


  class InputRequestScreener(ArchitectureComponent):
+ """
+ This is a concrete component which screens the input query for profanity using an off-the-shelf
+ profanity screening library (better_profanity)
+ """
  description = "Simplistic input screener for demonstration. Screens inputs for profanity."

  def process_request(self, request: ArchitectureRequest) -> None:
@@ -430,6 +472,12 @@ class InputRequestScreener(ArchitectureComponent):


  class OutputResponseScreener(ArchitectureComponent):
+ """
+ This is a concrete component designed to review the final response before showing it to the user.
+ It is a simple exemplar component which calls the baseline LLM with just the response text and asks
+ it whether the response contains anything offensive. This is illustrative only and should not be considered
+ a best-in-class or production-usable safety implementation.
+ """
  description = "Screens outputs for offensive responses."

  def __init__(self):
@@ -459,6 +507,11 @@ class OutputResponseScreener(ArchitectureComponent):


  class RetrievalAugmentor(ArchitectureComponent):
+ """
+ This is a concrete implementation of the retrieval augmentation component of the RAG architecture. Takes
+ the current input request, queries the vector store for documents and then inserts these documents at
+ the beginning of the LLM prompt, ready for inference.
+ """
  description = "Retrieves appropriate documents from the store and then augments the request."

  def __init__(self, vector_store: str, doc_count: int = 5):
@@ -494,7 +547,7 @@ class RetrievalAugmentor(ArchitectureComponent):

  class HFInferenceEndpoint(ArchitectureComponent):
  """
- A concrete pipeline component which sends the user text to a given llama chat based
+ A concrete pipeline component which sends the current query to a given llama chat based
  inference endpoint on HuggingFace
  """
  def __init__(self, endpoint_url: str, model_name: str, system_prompt: str, max_new_tokens: int,
@@ -522,7 +575,8 @@ class HFInferenceEndpoint(ArchitectureComponent):
  """
  Main processing method for this function. Calls the HTTP service for the model
  by port if provided or attempting to lookup by name, and then adds this to the
- response element of the request.
+ response element of the request. Supports the different prompt styles that were tried
+ during testing to determine the best way to get a good response from the various LLM endpoints.
  """
  headers = {
  "Accept": "application/json",
@@ -583,12 +637,3 @@ class ResponseTrimmer(ArchitectureComponent):

  def config_description(self) -> str:
  return f"Regexes: {self.regex_display}"
-
-
- if __name__ == "__main__":
- req = ArchitectureRequest("Testing")
- a = Architecture.get_architecture("1. Baseline LLM")
- a(req)
- print("Hold")
-
-