alfraser commited on
Commit
745c1f4
·
1 Parent(s): b927d45

Added saving of the trace data

Browse files
Files changed (1) hide show
  1. src/architectures.py +69 -4
src/architectures.py CHANGED
@@ -12,8 +12,9 @@ import traceback
12
 
13
  from abc import ABC, abstractmethod
14
  from enum import Enum
 
15
  from time import time
16
- from typing import List, Optional
17
  from better_profanity import profanity
18
 
19
  from src.common import config_dir, data_dir, hf_api_token, escape_dollars
@@ -62,6 +63,9 @@ class ArchitectureRequest:
62
  md += f"\n - {r}"
63
  return escape_dollars(md)
64
 
 
 
 
65
 
66
  class ArchitectureTraceOutcome(Enum):
67
  """
@@ -115,6 +119,16 @@ class ArchitectureTraceStep:
115
  md += f" - **Outcome**: {outcome}"
116
  return escape_dollars(md)
117
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  class ArchitectureTrace:
120
  """
@@ -143,6 +157,9 @@ class ArchitectureTrace:
143
  md = ' \n'.join([s.as_markdown() for s in self.steps])
144
  return md
145
 
 
 
 
146
 
147
  class ArchitectureComponent(ABC):
148
  description = "Components should override a description"
@@ -177,6 +194,8 @@ class Architecture:
177
  crash the system.
178
  """
179
  architectures = None
 
 
180
 
181
  @classmethod
182
  def load_architectures(cls, force_reload: bool = False) -> None:
@@ -221,6 +240,46 @@ class Architecture:
221
  return a
222
  raise ValueError(f"Could not find an architecture named {name}")
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  def __init__(self,
225
  name: str,
226
  description: str,
@@ -235,7 +294,7 @@ class Architecture:
235
  self.exception_text = exception_text
236
  self.no_response_text = no_response_text
237
 
238
- def __call__(self, request: ArchitectureRequest) -> ArchitectureTrace:
239
  """
240
  The main entry point to call the pipeline. Passes the request through each pipeline step
241
  in sequence, allowing them to amend the request or early exit the processing. Also captures
@@ -260,8 +319,7 @@ class Architecture:
260
  trace.steps[-1].exception = err
261
  traceback.print_exc()
262
  break
263
- # TODO - save the request / response
264
- # TODO - save the trace
265
  return trace
266
 
267
 
@@ -414,3 +472,10 @@ class ResponseTrimmer(ArchitectureComponent):
414
  return f"Regexes: {self.regex_display}"
415
 
416
 
 
 
 
 
 
 
 
 
12
 
13
  from abc import ABC, abstractmethod
14
  from enum import Enum
15
+ from huggingface_hub import Repository
16
  from time import time
17
+ from typing import List, Optional, Dict
18
  from better_profanity import profanity
19
 
20
  from src.common import config_dir, data_dir, hf_api_token, escape_dollars
 
63
  md += f"\n - {r}"
64
  return escape_dollars(md)
65
 
66
+ def as_dict(self) -> Dict:
67
+ return {'request_evolution': self._request, 'response_evolution': self._response}
68
+
69
 
70
  class ArchitectureTraceOutcome(Enum):
71
  """
 
119
  md += f" - **Outcome**: {outcome}"
120
  return escape_dollars(md)
121
 
122
+ def as_dict(self) -> Dict:
123
+ return {
124
+ 'name': self.name,
125
+ 'start_ms': self.start_ms,
126
+ 'end_ms': self.end_ms,
127
+ 'outcome': str(self.outcome),
128
+ 'exception': "" if self._exception is None else f"{self._exception}",
129
+ 'early_exit_message': "" if self.early_exit_message is None else self.early_exit_message
130
+ }
131
+
132
 
133
  class ArchitectureTrace:
134
  """
 
157
  md = ' \n'.join([s.as_markdown() for s in self.steps])
158
  return md
159
 
160
+ def as_dict(self) -> Dict:
161
+ return {'steps': [s.as_dict() for s in self.steps]}
162
+
163
 
164
  class ArchitectureComponent(ABC):
165
  description = "Components should override a description"
 
194
  crash the system.
195
  """
196
  architectures = None
197
+ save_repo = None
198
+ save_repo_load_error = False
199
 
200
  @classmethod
201
  def load_architectures(cls, force_reload: bool = False) -> None:
 
240
  return a
241
  raise ValueError(f"Could not find an architecture named {name}")
242
 
243
+ def append_and_save_data_as_json(self, data: Dict):
244
+ repo_url = "https://huggingface.co/datasets/alfraser/llm-arch-trace"
245
+ trace_dir = "trace"
246
+ trace_file_name = "trace.json"
247
+ data_file = os.path.join(trace_dir, trace_file_name)
248
+ if Architecture.save_repo is None and not Architecture.save_repo_load_error:
249
+ try:
250
+ hf_write_token = hf_api_token(write=True)
251
+ Architecture.save_repo = Repository(local_dir=trace_dir, clone_from=repo_url, token=hf_write_token)
252
+ except Exception as err:
253
+ Architecture.save_repo_load_error = True
254
+ print(f"Error connecting to the save repo {err} - persistence now disabled")
255
+
256
+ if Architecture.save_repo is not None:
257
+ with open(data_file, 'r') as f:
258
+ test_json = json.load(f)
259
+ test_json['tests'].append(data)
260
+ with open(data_file, 'w') as f:
261
+ json.dump(test_json, f, indent=2)
262
+ Architecture.save_repo.push_to_hub()
263
+
264
+ def attempt_request_and_trace_save(self, request: ArchitectureRequest, trace: ArchitectureTrace,
265
+ trace_tags: List[str] = None) -> None:
266
+ """
267
+ Attempt to save a request and trace pair to a json store on huggingface datasets
268
+ Catch any errors and simply print as non-fatal to functional flow
269
+ """
270
+ try:
271
+ if trace_tags is None:
272
+ trace_tags = []
273
+ save_dict = {
274
+ 'architecture': self.name,
275
+ 'request': request.as_dict(),
276
+ 'trace': trace.as_dict(),
277
+ 'trace_tags': trace_tags
278
+ }
279
+ self.append_and_save_data_as_json(save_dict)
280
+ except Exception as err:
281
+ print(f"Request / trace save failed {err}")
282
+
283
  def __init__(self,
284
  name: str,
285
  description: str,
 
294
  self.exception_text = exception_text
295
  self.no_response_text = no_response_text
296
 
297
+ def __call__(self, request: ArchitectureRequest, trace_tags: List[str] = None) -> ArchitectureTrace:
298
  """
299
  The main entry point to call the pipeline. Passes the request through each pipeline step
300
  in sequence, allowing them to amend the request or early exit the processing. Also captures
 
319
  trace.steps[-1].exception = err
320
  traceback.print_exc()
321
  break
322
+ self.attempt_request_and_trace_save(request, trace, trace_tags)
 
323
  return trace
324
 
325
 
 
472
  return f"Regexes: {self.regex_display}"
473
 
474
 
475
+ if __name__ == "__main__":
476
+ req = ArchitectureRequest("Testing")
477
+ a = Architecture.get_architecture("1. Baseline LLM")
478
+ a(req)
479
+ print("Hold")
480
+
481
+