Reviewed comments and type hints
src/architectures.py  CHANGED  (+63 -18)
@@ -1,6 +1,8 @@
 """
 This file contains all the code which defines architectures and
-architecture components.
+architecture components. An architecture is modelled as a pipeline of ArchitectureComponents
+through which an ArchitectureRequest flows. Architectures are configured in the file
+config/architectures.json
 """

 import chromadb
@@ -82,7 +84,7 @@ class ArchitectureTraceOutcome(Enum):

 class ArchitectureTraceStep:
     """
-    Class to hold the details of a single
+    Class to hold the trace details of a single step in an Architecture pipeline
     """
     def __init__(self, name: str):
         self.name = name
@@ -165,6 +167,11 @@ class ArchitectureTrace:


 class ArchitectureComponent(ABC):
+    """
+    This is the abstract base class for all classes which want to be concrete components available
+    to be configured into an Architecture pipeline. Specifies the elements which need to be implemented
+    to be a compliant architecture component.
+    """
     description = "Components should override a description"

     @abstractmethod
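In practice a compliant component is just a subclass that overrides description and implements process_request. The following minimal sketch illustrates that contract; ToyRequest and ToyComponent are simplified stand-ins for the real ArchitectureRequest and ArchitectureComponent classes, not code from this commit.

from abc import ABC, abstractmethod


class ToyRequest:
    # Simplified stand-in for ArchitectureRequest
    def __init__(self, request: str):
        self.request = request
        self.response = None


class ToyComponent(ABC):
    # Simplified stand-in for ArchitectureComponent
    description = "Components should override a description"

    @abstractmethod
    def process_request(self, request: ToyRequest) -> None:
        ...


class UpperCaseEcho(ToyComponent):
    # A trivial concrete component: it sets the response to the upper-cased query
    description = "Echoes the request text back in upper case."

    def process_request(self, request: ToyRequest) -> None:
        request.response = request.request.upper()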
@@ -188,6 +195,17 @@ class ArchitectureComponent(ABC):


 class LogWorker(Thread):
+    """
+    The LogWorker implements a daemon thread which runs in the background to write the results
+    of user queries through the system to a log file for analysis/reporting and offline saving.
+    The LogWorker provides two functions to the system: 1) it moves this I/O operation out of the
+    main architecture execution, which allows for a clearer understanding of the true performance of the
+    architectures themselves; 2) it is designed to be run as a single thread to provide controlled
+    shared access to a resource (the log file) with an in-memory queue for thread safety, which then
+    allows us to multi-thread the architecture invocation itself. In addition, the LogWorker provides
+    some basic batching capabilities for performance (e.g. it batches up to N requests before committing the I/O
+    operation to the file, or commits open activity after a set period of inactivity).
+    """
     instance = None
     architectures = None
     save_repo = None
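The queue-plus-timer pattern this docstring describes can be boiled down to a short sketch: a single daemon thread drains an in-memory queue, batches records, and a threading.Timer puts a sentinel on the queue after a period of inactivity to force a flush. The batch size, timeout and file format below are illustrative assumptions, not values from this repository.

import json
import queue
from threading import Thread, Timer


class MiniLogWorker(Thread):
    # Minimal illustration of the batching / inactivity-flush pattern
    COMMIT_AFTER_N = 10       # assumed batch size
    INACTIVITY_SECS = 30.0    # assumed inactivity window

    def __init__(self, log_path: str):
        super().__init__(daemon=True)
        self.log_path = log_path
        self.q = queue.Queue()
        self.pending = []
        self.timer = None

    def write(self, record: dict) -> None:
        # Called from any thread; never blocks on file I/O
        self.q.put(record)

    def _commit(self) -> None:
        if not self.pending:
            return
        with open(self.log_path, "a", encoding="utf-8") as f:
            for rec in self.pending:
                f.write(json.dumps(rec) + "\n")
        self.pending.clear()

    def run(self) -> None:
        while True:
            record = self.q.get()
            if record is None:
                # Inactivity signal from the timer: flush whatever is pending
                self._commit()
                continue
            if self.timer is not None and self.timer.is_alive():
                self.timer.cancel()
            self.pending.append(record)
            if len(self.pending) >= self.COMMIT_AFTER_N:
                self._commit()
            # Restart the inactivity timer
            self.timer = Timer(self.INACTIVITY_SECS, lambda: self.q.put(None))
            self.timer.start()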
@@ -207,6 +225,7 @@ class LogWorker(Thread):
         while True:
             arch_name, request, trace, trace_tags, trace_comment = LogWorker.queue.get()
             if request is None:
+                # There was a period of inactivity so run the timeout functions
                 for func in LogWorker.timeout_functions:
                     print(f"LogWorker commit running {func.__name__}")
                     try:
@@ -215,6 +234,7 @@
                         print(f"Timeout func {func.__name__} had error {e}")
             else:
                 if LogWorker.commit_timer is not None and LogWorker.commit_timer.is_alive():
+                    # Cancel the inactivity timer
                     LogWorker.commit_timer.cancel()
                     LogWorker.commit_timer = None
                 try:
@@ -232,11 +252,16 @@
             except Exception as err:
                 print(f"Request / trace save failed {err}")

+            # Restart the inactivity timer
             LogWorker.commit_timer = Timer(LogWorker.commit_time, LogWorker.signal_commit)
             LogWorker.commit_timer.start()

     @classmethod
-    def append_and_save_data_as_json(cls, data: Dict):
+    def append_and_save_data_as_json(cls, data: Dict) -> None:
+        """
+        If the working log file has not been downloaded, then get a local copy.
+        Add the new record to the local file.
+        """
         print(f"LogWorker logging open record {LogWorker.commit_count + 1}")
         if cls.save_repo is None and not cls.save_repo_load_error:
             try:
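As a rough illustration of "get a local copy, then append the record", the sketch below uses huggingface_hub's hf_hub_download to fetch an existing log file into a working location before appending. The repo id, filename and JSON layout are assumptions for illustration only, not the ones used by this Space.

import json
import os
import shutil
from huggingface_hub import hf_hub_download


def append_record_as_json(data: dict,
                          local_path: str = "data/trace_log.json",
                          repo_id: str = "your-org/your-trace-dataset") -> None:
    # Make sure a local working copy of the log file exists
    if not os.path.isfile(local_path):
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        try:
            cached = hf_hub_download(repo_id=repo_id, filename="trace_log.json",
                                     repo_type="dataset")
            shutil.copyfile(cached, local_path)
        except Exception:
            # No remote copy yet - start a fresh log
            with open(local_path, "w", encoding="utf-8") as f:
                json.dump({"records": []}, f)

    # Add the new record to the local file
    with open(local_path, "r", encoding="utf-8") as f:
        log = json.load(f)
    log["records"].append(data)
    with open(local_path, "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2)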
@@ -255,6 +280,9 @@

     @classmethod
     def commit_repo(cls):
+        """
+        If there are any changes in the local file which are not committed to the repo then commit them.
+        """
         if cls.commit_count > 0:
             print(f"LogWorker committing {LogWorker.commit_count} open records")
             cls.save_repo.push_to_hub()
@@ -270,7 +298,11 @@

     @classmethod
     def write(cls, arch_name: str, request: ArchitectureRequest, trace: ArchitectureTrace,
-              trace_tags: List[str] = None, trace_comment: str = None):
+              trace_tags: List[str] = None, trace_comment: str = None) -> None:
+        """
+        Class method callable from across the system to put a logging request onto the queue so that
+        the LogWorker will pick it up in turn and write it to the log
+        """
         trace_tags = [] if trace_tags is None else trace_tags
         trace_comment = "" if trace_comment is None else trace_comment
         cls.queue.put((arch_name, request, trace, trace_tags, trace_comment))
@@ -302,7 +334,10 @@ class Architecture:
     trace_file = os.path.join(trace_dir, trace_file_name)

     @classmethod
-    def wipe_trace(cls, hf_write_token:str = None):
+    def wipe_trace(cls, hf_write_token:str = None) -> None:
+        """
+        Wipes the json trace file - note that this will not delete any records which have been saved offline to the database
+        """
         if os.path.exists(cls.trace_dir):
             shutil.rmtree(cls.trace_dir)
         try:
@@ -319,6 +354,9 @@ class Architecture:

     @classmethod
     def get_trace_records(cls) -> List[Dict]:
+        """
+        Loads and returns all the trace records which are held in the trace file
+        """
         if not os.path.isfile(cls.trace_file):
             hf_write_token = hf_api_token(write=True)
             try:
@@ -395,8 +433,8 @@
         in sequence, allowing them to amend the request or early exit the processing. Also captures
         exceptions and generates the trace, plus saves the request/response and the trace to a store
         for analysis.
-        :param request:
-        :return:
+        :param request: The architecture request to pass down the pipeline
+        :return: The trace record for this invocation of the architecture
         """
         print(f'{self.name} processing query "{request.request}"')
         trace = ArchitectureTrace()
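The behaviour described in this docstring (run each component in order, time each step, capture exceptions, allow an early exit) reduces to a simple loop. The sketch below is a schematic version using plain dictionaries in place of the real ArchitectureTrace, and it assumes components signal an early exit via a flag on the request; the real classes in this file may differ in detail.

import time


class PipelineSketch:
    # Schematic stand-in for Architecture: holds an ordered list of components
    def __init__(self, name, components):
        self.name = name
        self.components = components

    def __call__(self, request):
        trace = []  # stand-in for ArchitectureTrace
        for component in self.components:
            start = time.perf_counter()
            try:
                component.process_request(request)
                outcome = "success"
            except Exception as err:
                outcome = f"exception: {err}"
            trace.append({
                "step": type(component).__name__,
                "outcome": outcome,
                "seconds": time.perf_counter() - start,
            })
            # Stop early if the component flagged an exit or raised
            if getattr(request, "early_exit", False) or outcome != "success":
                break
        return trace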
@@ -420,6 +458,10 @@


 class InputRequestScreener(ArchitectureComponent):
+    """
+    This is a concrete component which screens the input query for profanity using an off-the-shelf
+    profanity search library (better_profanity)
+    """
     description = "Simplistic input screener for demonstration. Screens inputs for profanity."

     def process_request(self, request: ArchitectureRequest) -> None:
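better_profanity exposes a very small API, so the screening step amounts to roughly the following sketch; how the result is wired back into the request object is left out here.

from better_profanity import profanity


def input_is_clean(text: str) -> bool:
    # Returns True when the query passes the profanity screen
    profanity.load_censor_words()  # default word list shipped with the library
    return not profanity.contains_profanity(text)


print(input_is_clean("What is retrieval augmented generation?"))  # True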
@@ -430,6 +472,12 @@ class InputRequestScreener(ArchitectureComponent):


 class OutputResponseScreener(ArchitectureComponent):
+    """
+    This is a concrete component designed to review the final response before showing it to the user.
+    It is a simple exemplar component which calls the baseline LLM with just the response text and asks
+    whether it contains anything offensive. This is illustrative only and should not be considered
+    a best-in-class or production-usable safety implementation.
+    """
     description = "Screens outputs for offensive responses."

     def __init__(self):
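The "ask the baseline LLM to review the output" idea can be sketched as a yes/no judge prompt. call_baseline_llm below is a hypothetical placeholder for whatever endpoint call the architecture uses; only the prompt shape and the verdict check are the point of the example.

def call_baseline_llm(prompt: str) -> str:
    # Hypothetical placeholder for the real inference endpoint call
    raise NotImplementedError


def response_is_safe(response_text: str) -> bool:
    prompt = (
        "You are a content reviewer. Answer only YES or NO.\n"
        "Does the following text contain anything offensive?\n\n"
        f"{response_text}"
    )
    verdict = call_baseline_llm(prompt).strip().upper()
    return verdict.startswith("NO")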
@@ -459,6 +507,11 @@ class OutputResponseScreener(ArchitectureComponent):


 class RetrievalAugmentor(ArchitectureComponent):
+    """
+    This is a concrete implementation of the RAG augmentation component of the RAG architecture. Takes
+    the current input request, queries the vector store for documents and then inserts these documents at
+    the beginning of the LLM prompt, ready for inference.
+    """
     description = "Retrieves appropriate documents from the store and then augments the request."

     def __init__(self, vector_store: str, doc_count: int = 5):
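Since the file already imports chromadb, the retrieve-then-augment step can be sketched directly against that client. The store path, collection name and prompt template below are assumptions for illustration; the query/n_results calls are standard chromadb API.

import chromadb


def augment_query(query: str, store_path: str = "vector_store",
                  collection_name: str = "docs", doc_count: int = 5) -> str:
    client = chromadb.PersistentClient(path=store_path)
    collection = client.get_collection(collection_name)
    results = collection.query(query_texts=[query], n_results=doc_count)
    documents = results["documents"][0]  # top-k chunks for the single query text
    context = "\n\n".join(documents)
    # Put the retrieved documents ahead of the user's question
    return (
        "Use the following context to answer the question.\n\n"
        f"{context}\n\nQuestion: {query}"
    )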
@@ -494,7 +547,7 @@ class RetrievalAugmentor(ArchitectureComponent):

 class HFInferenceEndpoint(ArchitectureComponent):
     """
-    A concrete pipeline component which sends the
+    A concrete pipeline component which sends the current query to a given llama chat-based
     inference endpoint on HuggingFace
     """
     def __init__(self, endpoint_url: str, model_name: str, system_prompt: str, max_new_tokens: int,
@@ -522,7 +575,8 @@ class HFInferenceEndpoint(ArchitectureComponent):
         """
         Main processing method for this function. Calls the HTTP service for the model
         by port if provided or attempting to lookup by name, and then adds this to the
-        response element of the request.
+        response element of the request. Supports the different prompt styles that were tested
+        to determine the best way to get a good response from the various LLM endpoints.
         """
         headers = {
             "Accept": "application/json",
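For reference, a call to a HuggingFace Inference Endpoint running a text-generation model generally looks like the sketch below: a bearer token in the headers and an "inputs" plus "parameters" JSON payload. The llama-2 style [INST]/<<SYS>> prompt format and the HF_API_TOKEN environment variable are assumptions, not values taken from this repository.

import os
import requests


def query_endpoint(endpoint_url: str, system_prompt: str, user_query: str,
                   max_new_tokens: int = 256, temperature: float = 0.7) -> str:
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['HF_API_TOKEN']}",  # assumed env var
    }
    # llama-2 chat style prompt; other models expect different templates
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_query} [/INST]"
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
    }
    response = requests.post(endpoint_url, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    body = response.json()
    # Text-generation endpoints typically return [{"generated_text": ...}]
    return body[0]["generated_text"] if isinstance(body, list) else body["generated_text"]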
@@ -583,12 +637,3 @@ class ResponseTrimmer(ArchitectureComponent):

     def config_description(self) -> str:
         return f"Regexes: {self.regex_display}"
-
-
-if __name__ == "__main__":
-    req = ArchitectureRequest("Testing")
-    a = Architecture.get_architecture("1. Baseline LLM")
-    a(req)
-    print("Hold")
-
-