Elron commited on
Commit
214d47a
·
1 Parent(s): 5c0e64f

Upload artifact.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. artifact.py +35 -18
artifact.py CHANGED
@@ -26,12 +26,21 @@ class Artifactories(object):
26
  def __next__(self):
27
  return next(self.artifactories)
28
 
29
- def register_atrifactory(self, artifactory):
30
  assert isinstance(artifactory, Artifactory), "Artifactory must be an instance of Artifactory"
31
  assert hasattr(artifactory, "__contains__"), "Artifactory must have __contains__ method"
32
  assert hasattr(artifactory, "__getitem__"), "Artifactory must have __getitem__ method"
33
  self.artifactories = [artifactory] + self.artifactories
34
 
 
 
 
 
 
 
 
 
 
35
 
36
  def map_values_in_place(object, mapper):
37
  if isinstance(object, dict):
@@ -55,12 +64,11 @@ def get_closest_artifact_type(type):
55
 
56
  class UnrecognizedArtifactType(ValueError):
57
  def __init__(self, type) -> None:
 
 
58
  closest_artifact_type = get_closest_artifact_type(type)
59
- message = (
60
- f"'{type}' is not a recognized value for 'type' parameter."
61
- "\n\n"
62
- f"Did you mean '{closest_artifact_type}'?"
63
- )
64
  super().__init__(message)
65
 
66
 
@@ -77,15 +85,15 @@ class Artifact(Dataclass):
77
 
78
  @classmethod
79
  def is_artifact_dict(cls, d):
80
- return isinstance(d, dict) and "type" in d and d["type"] in cls._class_register
81
 
82
  @classmethod
83
- def verify_is_artifact_dict(cls, d):
84
  if not isinstance(d, dict):
85
  raise ValueError(f"Artifact dict <{d}> must be of type 'dict', got '{type(d)}'.")
86
  if "type" not in d:
87
  raise MissingArtifactType(d)
88
- if d["type"] not in cls._class_register:
89
  raise UnrecognizedArtifactType(d["type"])
90
 
91
  @classmethod
@@ -103,7 +111,7 @@ class Artifact(Dataclass):
103
 
104
  snake_case_key = camel_to_snake_case(artifact_class.__name__)
105
 
106
- if snake_case_key in cls._class_register:
107
  assert (
108
  cls._class_register[snake_case_key] == artifact_class
109
  ), f"Artifact class name must be unique, {snake_case_key} already exists for {cls._class_register[snake_case_key]}"
@@ -114,6 +122,10 @@ class Artifact(Dataclass):
114
 
115
  return snake_case_key
116
 
 
 
 
 
117
  @classmethod
118
  def is_artifact_file(cls, path):
119
  if not os.path.exists(path) or not os.path.isfile(path):
@@ -122,6 +134,14 @@ class Artifact(Dataclass):
122
  d = json.load(f)
123
  return cls.is_artifact_dict(d)
124
 
 
 
 
 
 
 
 
 
125
  @classmethod
126
  def _recursive_load(cls, d):
127
  if isinstance(d, dict):
@@ -134,6 +154,7 @@ class Artifact(Dataclass):
134
  else:
135
  pass
136
  if cls.is_artifact_dict(d):
 
137
  instance = cls._class_register[d.pop("type")](**d)
138
  return instance
139
  else:
@@ -141,18 +162,14 @@ class Artifact(Dataclass):
141
 
142
  @classmethod
143
  def from_dict(cls, d):
144
- cls.verify_is_artifact_dict(d)
145
  return cls._recursive_load(d)
146
 
147
  @classmethod
148
  def load(cls, path):
149
- try:
150
- with open(path, "r") as f:
151
- d = json.load(f)
152
- assert "type" in d, "Saved artifact must have a type field"
153
- return cls._recursive_load(d)
154
- except Exception as e:
155
- raise Exception(f"{e}\n\nFailed to load artifact from {path} see above for more details.")
156
 
157
  def prepare(self):
158
  pass
 
26
  def __next__(self):
27
  return next(self.artifactories)
28
 
29
+ def register(self, artifactory):
30
  assert isinstance(artifactory, Artifactory), "Artifactory must be an instance of Artifactory"
31
  assert hasattr(artifactory, "__contains__"), "Artifactory must have __contains__ method"
32
  assert hasattr(artifactory, "__getitem__"), "Artifactory must have __getitem__ method"
33
  self.artifactories = [artifactory] + self.artifactories
34
 
35
+ def unregister(self, artifactory):
36
+ assert isinstance(artifactory, Artifactory), "Artifactory must be an instance of Artifactory"
37
+ assert hasattr(artifactory, "__contains__"), "Artifactory must have __contains__ method"
38
+ assert hasattr(artifactory, "__getitem__"), "Artifactory must have __getitem__ method"
39
+ self.artifactories.remove(artifactory)
40
+
41
+ def reset(self):
42
+ self.artifactories = []
43
+
44
 
45
  def map_values_in_place(object, mapper):
46
  if isinstance(object, dict):
 
64
 
65
  class UnrecognizedArtifactType(ValueError):
66
  def __init__(self, type) -> None:
67
+ maybe_class = "".join(word.capitalize() for word in type.split("_"))
68
+ message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed."
69
  closest_artifact_type = get_closest_artifact_type(type)
70
+ if closest_artifact_type is not None:
71
+ message += "\n\n" f"Did you mean '{closest_artifact_type}'?"
 
 
 
72
  super().__init__(message)
73
 
74
 
 
85
 
86
  @classmethod
87
  def is_artifact_dict(cls, d):
88
+ return isinstance(d, dict) and "type" in d
89
 
90
  @classmethod
91
+ def verify_artifact_dict(cls, d):
92
  if not isinstance(d, dict):
93
  raise ValueError(f"Artifact dict <{d}> must be of type 'dict', got '{type(d)}'.")
94
  if "type" not in d:
95
  raise MissingArtifactType(d)
96
+ if not cls.is_registered_type(d["type"]):
97
  raise UnrecognizedArtifactType(d["type"])
98
 
99
  @classmethod
 
111
 
112
  snake_case_key = camel_to_snake_case(artifact_class.__name__)
113
 
114
+ if cls.is_registered_type(snake_case_key):
115
  assert (
116
  cls._class_register[snake_case_key] == artifact_class
117
  ), f"Artifact class name must be unique, {snake_case_key} already exists for {cls._class_register[snake_case_key]}"
 
122
 
123
  return snake_case_key
124
 
125
+ def __init_subclass__(cls, **kwargs):
126
+ super().__init_subclass__(**kwargs)
127
+ cls.register_class(cls)
128
+
129
  @classmethod
130
  def is_artifact_file(cls, path):
131
  if not os.path.exists(path) or not os.path.isfile(path):
 
134
  d = json.load(f)
135
  return cls.is_artifact_dict(d)
136
 
137
+ @classmethod
138
+ def is_registered_type(cls, type: str):
139
+ return type in cls._class_register
140
+
141
+ @classmethod
142
+ def is_registered_class(cls, clz: object):
143
+ return clz in set(cls._class_register.values())
144
+
145
  @classmethod
146
  def _recursive_load(cls, d):
147
  if isinstance(d, dict):
 
154
  else:
155
  pass
156
  if cls.is_artifact_dict(d):
157
+ cls.verify_artifact_dict(d)
158
  instance = cls._class_register[d.pop("type")](**d)
159
  return instance
160
  else:
 
162
 
163
  @classmethod
164
  def from_dict(cls, d):
165
+ cls.verify_artifact_dict(d)
166
  return cls._recursive_load(d)
167
 
168
  @classmethod
169
  def load(cls, path):
170
+ with open(path, "r") as f:
171
+ d = json.load(f)
172
+ return cls.from_dict(d)
 
 
 
 
173
 
174
  def prepare(self):
175
  pass