Spaces:
Runtime error
Runtime error
Reviewed comments and type hints
Browse files- src/architectures.py +63 -18
src/architectures.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
"""
|
2 |
This file contains all the code which defines architectures and
|
3 |
-
architecture components.
|
|
|
|
|
4 |
"""
|
5 |
|
6 |
import chromadb
|
@@ -82,7 +84,7 @@ class ArchitectureTraceOutcome(Enum):
|
|
82 |
|
83 |
class ArchitectureTraceStep:
|
84 |
"""
|
85 |
-
Class to hold the details of a single
|
86 |
"""
|
87 |
def __init__(self, name: str):
|
88 |
self.name = name
|
@@ -165,6 +167,11 @@ class ArchitectureTrace:
|
|
165 |
|
166 |
|
167 |
class ArchitectureComponent(ABC):
|
|
|
|
|
|
|
|
|
|
|
168 |
description = "Components should override a description"
|
169 |
|
170 |
@abstractmethod
|
@@ -188,6 +195,17 @@ class ArchitectureComponent(ABC):
|
|
188 |
|
189 |
|
190 |
class LogWorker(Thread):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
instance = None
|
192 |
architectures = None
|
193 |
save_repo = None
|
@@ -207,6 +225,7 @@ class LogWorker(Thread):
|
|
207 |
while True:
|
208 |
arch_name, request, trace, trace_tags, trace_comment = LogWorker.queue.get()
|
209 |
if request is None:
|
|
|
210 |
for func in LogWorker.timeout_functions:
|
211 |
print(f"LogWorker commit running {func.__name__}")
|
212 |
try:
|
@@ -215,6 +234,7 @@ class LogWorker(Thread):
|
|
215 |
print(f"Timeout func {func.__name__} had error {e}")
|
216 |
else:
|
217 |
if LogWorker.commit_timer is not None and LogWorker.commit_timer.is_alive():
|
|
|
218 |
LogWorker.commit_timer.cancel()
|
219 |
LogWorker.commit_timer = None
|
220 |
try:
|
@@ -232,11 +252,16 @@ class LogWorker(Thread):
|
|
232 |
except Exception as err:
|
233 |
print(f"Request / trace save failed {err}")
|
234 |
|
|
|
235 |
LogWorker.commit_timer = Timer(LogWorker.commit_time, LogWorker.signal_commit)
|
236 |
LogWorker.commit_timer.start()
|
237 |
|
238 |
@classmethod
|
239 |
-
def append_and_save_data_as_json(cls, data: Dict):
|
|
|
|
|
|
|
|
|
240 |
print(f"LogWorker logging open record {LogWorker.commit_count + 1}")
|
241 |
if cls.save_repo is None and not cls.save_repo_load_error:
|
242 |
try:
|
@@ -255,6 +280,9 @@ class LogWorker(Thread):
|
|
255 |
|
256 |
@classmethod
|
257 |
def commit_repo(cls):
|
|
|
|
|
|
|
258 |
if cls.commit_count > 0:
|
259 |
print(f"LogWorker committing {LogWorker.commit_count} open records")
|
260 |
cls.save_repo.push_to_hub()
|
@@ -270,7 +298,11 @@ class LogWorker(Thread):
|
|
270 |
|
271 |
@classmethod
|
272 |
def write(cls, arch_name: str, request: ArchitectureRequest, trace: ArchitectureTrace,
|
273 |
-
trace_tags: List[str] = None, trace_comment: str = None):
|
|
|
|
|
|
|
|
|
274 |
trace_tags = [] if trace_tags is None else trace_tags
|
275 |
trace_comment = "" if trace_comment is None else trace_comment
|
276 |
cls.queue.put((arch_name, request, trace, trace_tags, trace_comment))
|
@@ -302,7 +334,10 @@ class Architecture:
|
|
302 |
trace_file = os.path.join(trace_dir, trace_file_name)
|
303 |
|
304 |
@classmethod
|
305 |
-
def wipe_trace(cls, hf_write_token:str = None):
|
|
|
|
|
|
|
306 |
if os.path.exists(cls.trace_dir):
|
307 |
shutil.rmtree(cls.trace_dir)
|
308 |
try:
|
@@ -319,6 +354,9 @@ class Architecture:
|
|
319 |
|
320 |
@classmethod
|
321 |
def get_trace_records(cls) -> List[Dict]:
|
|
|
|
|
|
|
322 |
if not os.path.isfile(cls.trace_file):
|
323 |
hf_write_token = hf_api_token(write=True)
|
324 |
try:
|
@@ -395,8 +433,8 @@ class Architecture:
|
|
395 |
in sequence, allowing them to amend the request or early exit the processing. Also captures
|
396 |
exceptions and generates the trace, plus saves the request/response and the trace to a store
|
397 |
for analysis.
|
398 |
-
:param request:
|
399 |
-
:return:
|
400 |
"""
|
401 |
print(f'{self.name} processing query "{request.request}"')
|
402 |
trace = ArchitectureTrace()
|
@@ -420,6 +458,10 @@ class Architecture:
|
|
420 |
|
421 |
|
422 |
class InputRequestScreener(ArchitectureComponent):
|
|
|
|
|
|
|
|
|
423 |
description = "Simplistic input screener for demonstration. Screens inputs for profanity."
|
424 |
|
425 |
def process_request(self, request: ArchitectureRequest) -> None:
|
@@ -430,6 +472,12 @@ class InputRequestScreener(ArchitectureComponent):
|
|
430 |
|
431 |
|
432 |
class OutputResponseScreener(ArchitectureComponent):
|
|
|
|
|
|
|
|
|
|
|
|
|
433 |
description = "Screens outputs for offensive responses."
|
434 |
|
435 |
def __init__(self):
|
@@ -459,6 +507,11 @@ class OutputResponseScreener(ArchitectureComponent):
|
|
459 |
|
460 |
|
461 |
class RetrievalAugmentor(ArchitectureComponent):
|
|
|
|
|
|
|
|
|
|
|
462 |
description = "Retrieves appropriate documents from the store and then augments the request."
|
463 |
|
464 |
def __init__(self, vector_store: str, doc_count: int = 5):
|
@@ -494,7 +547,7 @@ class RetrievalAugmentor(ArchitectureComponent):
|
|
494 |
|
495 |
class HFInferenceEndpoint(ArchitectureComponent):
|
496 |
"""
|
497 |
-
A concrete pipeline component which sends the
|
498 |
inference endpoint on HuggingFace
|
499 |
"""
|
500 |
def __init__(self, endpoint_url: str, model_name: str, system_prompt: str, max_new_tokens: int,
|
@@ -522,7 +575,8 @@ class HFInferenceEndpoint(ArchitectureComponent):
|
|
522 |
"""
|
523 |
Main processing method for this function. Calls the HTTP service for the model
|
524 |
by port if provided or attempting to lookup by name, and then adds this to the
|
525 |
-
response element of the request.
|
|
|
526 |
"""
|
527 |
headers = {
|
528 |
"Accept": "application/json",
|
@@ -583,12 +637,3 @@ class ResponseTrimmer(ArchitectureComponent):
|
|
583 |
|
584 |
def config_description(self) -> str:
|
585 |
return f"Regexes: {self.regex_display}"
|
586 |
-
|
587 |
-
|
588 |
-
if __name__ == "__main__":
|
589 |
-
req = ArchitectureRequest("Testing")
|
590 |
-
a = Architecture.get_architecture("1. Baseline LLM")
|
591 |
-
a(req)
|
592 |
-
print("Hold")
|
593 |
-
|
594 |
-
|
|
|
1 |
"""
|
2 |
This file contains all the code which defines architectures and
|
3 |
+
architecture components. An architecture is modelled as a pipeline of ArchitectureComponents
|
4 |
+
through which an ArchitectureRequest flows. Architectures are configured in the file
|
5 |
+
config/architectures.json
|
6 |
"""
|
7 |
|
8 |
import chromadb
|
|
|
84 |
|
85 |
class ArchitectureTraceStep:
|
86 |
"""
|
87 |
+
Class to hold the trace details of a single step in an Architecture pipeline
|
88 |
"""
|
89 |
def __init__(self, name: str):
|
90 |
self.name = name
|
|
|
167 |
|
168 |
|
169 |
class ArchitectureComponent(ABC):
|
170 |
+
"""
|
171 |
+
This is the abstract base class for all classes which want to be concrete components available
|
172 |
+
to be configured into an Architecture pipeline. Specifies the elements which need to be implemented
|
173 |
+
to be a compliant architecture component.
|
174 |
+
"""
|
175 |
description = "Components should override a description"
|
176 |
|
177 |
@abstractmethod
|
|
|
195 |
|
196 |
|
197 |
class LogWorker(Thread):
|
198 |
+
"""
|
199 |
+
The LogWorker implements a daemon thread which runs in the background to write the results
|
200 |
+
of user queries through the system to a log file for analysis/reporting and offline saving.
|
201 |
+
The LogWorker provides two functions to the system. 1) it moves this I/O operation out of the
|
202 |
+
main architecture execution which allows for clearer understanding of the true performance of the
|
203 |
+
architectures themselves. 2) it is designed to be run as a single thread to provide controlled
|
204 |
+
shared access to a resource (the log file) with an in-memory queue for thread safety, which then
|
205 |
+
allows us to multi-thread the architecture invocation itself. In addition, the LogWorker provides
|
206 |
+
some basic batching capabilities for performance (e.g. batches up N requests before committing the IO
|
207 |
+
operation to the file, or commits open activity after a set period of inactivity)
|
208 |
+
"""
|
209 |
instance = None
|
210 |
architectures = None
|
211 |
save_repo = None
|
|
|
225 |
while True:
|
226 |
arch_name, request, trace, trace_tags, trace_comment = LogWorker.queue.get()
|
227 |
if request is None:
|
228 |
+
# There was a period of inactivity so run the timeout functions
|
229 |
for func in LogWorker.timeout_functions:
|
230 |
print(f"LogWorker commit running {func.__name__}")
|
231 |
try:
|
|
|
234 |
print(f"Timeout func {func.__name__} had error {e}")
|
235 |
else:
|
236 |
if LogWorker.commit_timer is not None and LogWorker.commit_timer.is_alive():
|
237 |
+
# Cancel the inactivity timer
|
238 |
LogWorker.commit_timer.cancel()
|
239 |
LogWorker.commit_timer = None
|
240 |
try:
|
|
|
252 |
except Exception as err:
|
253 |
print(f"Request / trace save failed {err}")
|
254 |
|
255 |
+
# Restart the inactivity timer
|
256 |
LogWorker.commit_timer = Timer(LogWorker.commit_time, LogWorker.signal_commit)
|
257 |
LogWorker.commit_timer.start()
|
258 |
|
259 |
@classmethod
|
260 |
+
def append_and_save_data_as_json(cls, data: Dict) -> None:
|
261 |
+
"""
|
262 |
+
If the working log file has not been downloaded, then get a local copy.
|
263 |
+
Add the new record to the local file.
|
264 |
+
"""
|
265 |
print(f"LogWorker logging open record {LogWorker.commit_count + 1}")
|
266 |
if cls.save_repo is None and not cls.save_repo_load_error:
|
267 |
try:
|
|
|
280 |
|
281 |
@classmethod
|
282 |
def commit_repo(cls):
|
283 |
+
"""
|
284 |
+
If there are any changes in the local file which are not committed to the repo then commit them.
|
285 |
+
"""
|
286 |
if cls.commit_count > 0:
|
287 |
print(f"LogWorker committing {LogWorker.commit_count} open records")
|
288 |
cls.save_repo.push_to_hub()
|
|
|
298 |
|
299 |
@classmethod
|
300 |
def write(cls, arch_name: str, request: ArchitectureRequest, trace: ArchitectureTrace,
|
301 |
+
trace_tags: List[str] = None, trace_comment: str = None) -> None:
|
302 |
+
"""
|
303 |
+
Class method callable from across the system to put a logging request onto the queue so that
|
304 |
+
the LogWorker will pick it up in turn and write it to the log
|
305 |
+
"""
|
306 |
trace_tags = [] if trace_tags is None else trace_tags
|
307 |
trace_comment = "" if trace_comment is None else trace_comment
|
308 |
cls.queue.put((arch_name, request, trace, trace_tags, trace_comment))
|
|
|
334 |
trace_file = os.path.join(trace_dir, trace_file_name)
|
335 |
|
336 |
@classmethod
|
337 |
+
def wipe_trace(cls, hf_write_token:str = None) -> None:
|
338 |
+
"""
|
339 |
+
Wipes the JSON trace file. Note: this will not delete any records which have been saved offline to the database
|
340 |
+
"""
|
341 |
if os.path.exists(cls.trace_dir):
|
342 |
shutil.rmtree(cls.trace_dir)
|
343 |
try:
|
|
|
354 |
|
355 |
@classmethod
|
356 |
def get_trace_records(cls) -> List[Dict]:
|
357 |
+
"""
|
358 |
+
Loads and returns all the trace records which are held in the trace file
|
359 |
+
"""
|
360 |
if not os.path.isfile(cls.trace_file):
|
361 |
hf_write_token = hf_api_token(write=True)
|
362 |
try:
|
|
|
433 |
in sequence, allowing them to amend the request or early exit the processing. Also captures
|
434 |
exceptions and generates the trace, plus saves the request/response and the trace to a store
|
435 |
for analysis.
|
436 |
+
:param request: The architecture request to pass down the pipeline
|
437 |
+
:return: The trace record for this invocation of the architecture
|
438 |
"""
|
439 |
print(f'{self.name} processing query "{request.request}"')
|
440 |
trace = ArchitectureTrace()
|
|
|
458 |
|
459 |
|
460 |
class InputRequestScreener(ArchitectureComponent):
|
461 |
+
"""
|
462 |
+
This is a concrete component which screens the input query for profanity using an off the shelf
|
463 |
+
profanity search library (better_profanity)
|
464 |
+
"""
|
465 |
description = "Simplistic input screener for demonstration. Screens inputs for profanity."
|
466 |
|
467 |
def process_request(self, request: ArchitectureRequest) -> None:
|
|
|
472 |
|
473 |
|
474 |
class OutputResponseScreener(ArchitectureComponent):
|
475 |
+
"""
|
476 |
+
This is a concrete component designed to review the final response before showing it to the user.
|
477 |
+
It is a simple exemplar component using a call to the baseline LLM just with the response text and asking
|
478 |
+
the baseline LLM if it contains anything offensive. This is illustrative only and should not be considered
|
479 |
+
a best in class or production usable safety implementation.
|
480 |
+
"""
|
481 |
description = "Screens outputs for offensive responses."
|
482 |
|
483 |
def __init__(self):
|
|
|
507 |
|
508 |
|
509 |
class RetrievalAugmentor(ArchitectureComponent):
|
510 |
+
"""
|
511 |
+
This is a concrete implementation of the RAG augmentation component of the RAG architecture. Takes
|
512 |
+
the current input request, queries the vector store for documents and then appends these documents into
|
513 |
+
the beginning of the LLM prompt, ready for inference.
|
514 |
+
"""
|
515 |
description = "Retrieves appropriate documents from the store and then augments the request."
|
516 |
|
517 |
def __init__(self, vector_store: str, doc_count: int = 5):
|
|
|
547 |
|
548 |
class HFInferenceEndpoint(ArchitectureComponent):
|
549 |
"""
|
550 |
+
A concrete pipeline component which sends the current query to a given llama chat based
|
551 |
inference endpoint on HuggingFace
|
552 |
"""
|
553 |
def __init__(self, endpoint_url: str, model_name: str, system_prompt: str, max_new_tokens: int,
|
|
|
575 |
"""
|
576 |
Main processing method for this function. Calls the HTTP service for the model
|
577 |
by port if provided or attempting to lookup by name, and then adds this to the
|
578 |
+
response element of the request. Supports different prompt styles that were tried
|
579 |
+
during testing to determine the best way to get a good response from the various LLM endpoints.
|
580 |
"""
|
581 |
headers = {
|
582 |
"Accept": "application/json",
|
|
|
637 |
|
638 |
def config_description(self) -> str:
|
639 |
return f"Regexes: {self.regex_display}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|