Elron commited on
Commit
ba3eb02
·
verified ·
1 Parent(s): 70deb6e

Upload artifact.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. artifact.py +19 -15
artifact.py CHANGED
@@ -1,7 +1,6 @@
1
  import difflib
2
  import inspect
3
  import json
4
- import logging
5
  import os
6
  import pkgutil
7
  from abc import abstractmethod
@@ -9,8 +8,12 @@ from copy import deepcopy
9
  from typing import Dict, List, Union, final
10
 
11
  from .dataclass import Dataclass, Field, fields
 
12
  from .text_utils import camel_to_snake_case, is_camel_case
13
  from .type_utils import issubtype
 
 
 
14
 
15
 
16
  class Artifactories:
@@ -121,17 +124,17 @@ class Artifact(Dataclass):
121
  def register_class(cls, artifact_class):
122
  assert issubclass(
123
  artifact_class, Artifact
124
- ), f"Artifact class must be a subclass of Artifact, got {artifact_class}"
125
  assert is_camel_case(
126
  artifact_class.__name__
127
- ), f"Artifact class name must be legal camel case, got {artifact_class.__name__}"
128
 
129
  snake_case_key = camel_to_snake_case(artifact_class.__name__)
130
 
131
  if cls.is_registered_type(snake_case_key):
132
  assert (
133
  cls._class_register[snake_case_key] == artifact_class
134
- ), f"Artifact class name must be unique, {snake_case_key} already exists for {cls._class_register[snake_case_key]}"
135
 
136
  return snake_case_key
137
 
@@ -155,6 +158,11 @@ class Artifact(Dataclass):
155
  def is_registered_type(cls, type: str):
156
  return type in cls._class_register
157
 
 
 
 
 
 
158
  @classmethod
159
  def is_registered_class(cls, clz: object):
160
  return clz in set(cls._class_register.values())
@@ -183,8 +191,7 @@ class Artifact(Dataclass):
183
 
184
  @classmethod
185
  def load(cls, path):
186
- with open(path) as f:
187
- d = json.load(f)
188
  return cls.from_dict(d)
189
 
190
  def prepare(self):
@@ -216,11 +223,8 @@ class Artifact(Dataclass):
216
  return {"type": self.type, **self._init_dict}
217
 
218
  def save(self, path):
219
- with open(path, "w") as f:
220
- init_dict = self.to_dict()
221
- dumped = json.dumps(init_dict, indent=4)
222
- f.write(dumped)
223
- f.write("\n")
224
 
225
 
226
  class ArtifactList(list, Artifact):
@@ -261,7 +265,7 @@ def fetch_artifact(name):
261
 
262
  def verbosed_fetch_artifact(identifer):
263
  artifact, artifactory = fetch_artifact(identifer)
264
- logging.info(f"Artifact {identifer} is fetched from {artifactory}")
265
  return artifact
266
 
267
 
@@ -274,10 +278,10 @@ def maybe_recover_artifact(artifact):
274
 
275
  def register_all_artifacts(path):
276
  for loader, module_name, _is_pkg in pkgutil.walk_packages(path):
277
- logging.info(__name__)
278
  if module_name == __name__:
279
  continue
280
- logging.info(f"Loading {module_name}")
281
  # Import the module
282
  module = loader.find_module(module_name).load_module(module_name)
283
 
@@ -287,4 +291,4 @@ def register_all_artifacts(path):
287
  if inspect.isclass(obj):
288
  # Make sure the class is a subclass of Artifact (but not Artifact itself)
289
  if issubclass(obj, Artifact) and obj is not Artifact:
290
- logging.info(obj)
 
1
  import difflib
2
  import inspect
3
  import json
 
4
  import os
5
  import pkgutil
6
  from abc import abstractmethod
 
8
  from typing import Dict, List, Union, final
9
 
10
  from .dataclass import Dataclass, Field, fields
11
+ from .logging_utils import get_logger
12
  from .text_utils import camel_to_snake_case, is_camel_case
13
  from .type_utils import issubtype
14
+ from .utils import load_json, save_json
15
+
16
+ logger = get_logger()
17
 
18
 
19
  class Artifactories:
 
124
  def register_class(cls, artifact_class):
125
  assert issubclass(
126
  artifact_class, Artifact
127
+ ), f"Artifact class must be a subclass of Artifact, got '{artifact_class}'"
128
  assert is_camel_case(
129
  artifact_class.__name__
130
+ ), f"Artifact class name must be legal camel case, got '{artifact_class.__name__}'"
131
 
132
  snake_case_key = camel_to_snake_case(artifact_class.__name__)
133
 
134
  if cls.is_registered_type(snake_case_key):
135
  assert (
136
  cls._class_register[snake_case_key] == artifact_class
137
+ ), f"Artifact class name must be unique, '{snake_case_key}' already exists for '{cls._class_register[snake_case_key]}'"
138
 
139
  return snake_case_key
140
 
 
158
  def is_registered_type(cls, type: str):
159
  return type in cls._class_register
160
 
161
+ @classmethod
162
+ def is_registered_class_name(cls, class_name: str):
163
+ snake_case_key = camel_to_snake_case(class_name)
164
+ return cls.is_registered_type(snake_case_key)
165
+
166
  @classmethod
167
  def is_registered_class(cls, clz: object):
168
  return clz in set(cls._class_register.values())
 
191
 
192
  @classmethod
193
  def load(cls, path):
194
+ d = load_json(path)
 
195
  return cls.from_dict(d)
196
 
197
  def prepare(self):
 
223
  return {"type": self.type, **self._init_dict}
224
 
225
  def save(self, path):
226
+ data = self.to_dict()
227
+ save_json(path, data)
 
 
 
228
 
229
 
230
  class ArtifactList(list, Artifact):
 
265
 
266
  def verbosed_fetch_artifact(identifer):
267
  artifact, artifactory = fetch_artifact(identifer)
268
+ logger.info(f"Artifact {identifer} is fetched from {artifactory}")
269
  return artifact
270
 
271
 
 
278
 
279
  def register_all_artifacts(path):
280
  for loader, module_name, _is_pkg in pkgutil.walk_packages(path):
281
+ logger.info(__name__)
282
  if module_name == __name__:
283
  continue
284
+ logger.info(f"Loading {module_name}")
285
  # Import the module
286
  module = loader.find_module(module_name).load_module(module_name)
287
 
 
291
  if inspect.isclass(obj):
292
  # Make sure the class is a subclass of Artifact (but not Artifact itself)
293
  if issubclass(obj, Artifact) and obj is not Artifact:
294
+ logger.info(obj)