alfraser commited on
Commit
8f424fc
·
1 Parent(s): 3e88400

Refactored to bring common variables together. Also added a utility to get all the trace records as a list of records

Browse files
Files changed (1) hide show
  1. src/architectures.py +32 -23
src/architectures.py CHANGED
@@ -197,27 +197,39 @@ class Architecture:
197
  architectures = None
198
  save_repo = None
199
  save_repo_load_error = False
 
 
 
 
200
 
201
  @classmethod
202
  def wipe_trace(cls):
203
- repo_url = "https://huggingface.co/datasets/alfraser/llm-arch-trace"
204
- trace_dir = "trace"
205
- trace_file_name = "trace.json"
206
- data_file = os.path.join(trace_dir, trace_file_name)
207
- if os.path.exists(trace_dir):
208
- shutil.rmtree(trace_dir)
209
-
210
  try:
211
  hf_write_token = hf_api_token(write=True)
212
- Architecture.save_repo = Repository(local_dir=trace_dir, clone_from=repo_url, token=hf_write_token)
213
  test_json = {'tests': []}
214
- with open(data_file, 'w') as f:
215
  json.dump(test_json, f, indent=2)
216
- Architecture.save_repo.push_to_hub()
217
  except Exception as err:
218
- Architecture.save_repo_load_error = True
219
  print(f"Error connecting to the save repo {err} - persistence now disabled")
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  @classmethod
223
  def load_architectures(cls, force_reload: bool = False) -> None:
@@ -262,26 +274,23 @@ class Architecture:
262
  return a
263
  raise ValueError(f"Could not find an architecture named {name}")
264
 
265
- def append_and_save_data_as_json(self, data: Dict):
266
- repo_url = "https://huggingface.co/datasets/alfraser/llm-arch-trace"
267
- trace_dir = "trace"
268
- trace_file_name = "trace.json"
269
- data_file = os.path.join(trace_dir, trace_file_name)
270
- if Architecture.save_repo is None and not Architecture.save_repo_load_error:
271
  try:
272
  hf_write_token = hf_api_token(write=True)
273
- Architecture.save_repo = Repository(local_dir=trace_dir, clone_from=repo_url, token=hf_write_token)
274
  except Exception as err:
275
- Architecture.save_repo_load_error = True
276
  print(f"Error connecting to the save repo {err} - persistence now disabled")
277
 
278
- if Architecture.save_repo is not None:
279
- with open(data_file, 'r') as f:
280
  test_json = json.load(f)
281
  test_json['tests'].append(data)
282
- with open(data_file, 'w') as f:
283
  json.dump(test_json, f, indent=2)
284
- Architecture.save_repo.push_to_hub()
285
 
286
  def attempt_request_and_trace_save(self, request: ArchitectureRequest, trace: ArchitectureTrace,
287
  trace_tags: List[str] = None, trace_comment: str = None) -> None:
 
197
  architectures = None
198
  save_repo = None
199
  save_repo_load_error = False
200
+ save_repo_url = "https://huggingface.co/datasets/alfraser/llm-arch-trace"
201
+ trace_dir = "trace"
202
+ trace_file_name = "trace.json"
203
+ trace_file = os.path.join(trace_dir, trace_file_name)
204
 
205
  @classmethod
206
  def wipe_trace(cls):
207
+ if os.path.exists(cls.trace_dir):
208
+ shutil.rmtree(cls.trace_dir)
 
 
 
 
 
209
  try:
210
  hf_write_token = hf_api_token(write=True)
211
+ cls.save_repo = Repository(local_dir=cls.trace_dir, clone_from=cls.save_repo_url, token=hf_write_token)
212
  test_json = {'tests': []}
213
+ with open(cls.trace_file, 'w') as f:
214
  json.dump(test_json, f, indent=2)
215
+ cls.save_repo.push_to_hub()
216
  except Exception as err:
217
+ cls.save_repo_load_error = True
218
  print(f"Error connecting to the save repo {err} - persistence now disabled")
219
 
220
+ @classmethod
221
+ def get_trace_records(cls) -> List[Dict]:
222
+ if not os.path.isfile(cls.trace_file):
223
+ hf_write_token = hf_api_token(write=True)
224
+ try:
225
+ cls.save_repo = Repository(local_dir=cls.trace_dir, clone_from=cls.save_repo_url, token=hf_write_token)
226
+ except Exception as err:
227
+ cls.save_repo_load_error = True
228
+ print(f"Error connecting to the save repo {err} - persistence now disabled")
229
+ return []
230
+ with open(cls.trace_file, 'r') as f:
231
+ test_json = json.load(f)
232
+ return test_json['tests']
233
 
234
  @classmethod
235
  def load_architectures(cls, force_reload: bool = False) -> None:
 
274
  return a
275
  raise ValueError(f"Could not find an architecture named {name}")
276
 
277
+ @classmethod
278
+ def append_and_save_data_as_json(cls, data: Dict):
279
+ if cls.save_repo is None and not cls.save_repo_load_error:
 
 
 
280
  try:
281
  hf_write_token = hf_api_token(write=True)
282
+ cls.save_repo = Repository(local_dir=cls.trace_dir, clone_from=cls.repo_url, token=hf_write_token)
283
  except Exception as err:
284
+ cls.save_repo_load_error = True
285
  print(f"Error connecting to the save repo {err} - persistence now disabled")
286
 
287
+ if cls.save_repo is not None:
288
+ with open(cls.data_file, 'r') as f:
289
  test_json = json.load(f)
290
  test_json['tests'].append(data)
291
+ with open(cls.data_file, 'w') as f:
292
  json.dump(test_json, f, indent=2)
293
+ cls.save_repo.push_to_hub()
294
 
295
  def attempt_request_and_trace_save(self, request: ArchitectureRequest, trace: ArchitectureTrace,
296
  trace_tags: List[str] = None, trace_comment: str = None) -> None: