Elron commited on
Commit
534245a
·
1 Parent(s): aff1020

Upload dataclass.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataclass.py +39 -11
dataclass.py CHANGED
@@ -235,11 +235,7 @@ def asdict(obj):
235
 
236
  def _asdict_inner(obj):
237
  if is_dataclass(obj):
238
- result = {}
239
- for field in fields(obj):
240
- v = getattr(obj, field.name)
241
- result[field.name] = _asdict_inner(v)
242
- return result
243
  elif isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
244
  return type(obj)(*[_asdict_inner(v) for v in obj])
245
  elif isinstance(obj, (list, tuple)):
@@ -340,16 +336,36 @@ class Dataclass(metaclass=DataclassMeta):
340
  if name in kwargs:
341
  raise TypeError(f"{self.__class__.__name__} got multiple values for argument '{name}'")
342
 
 
 
343
  if len(argv) <= len(_init_positional_fields_names):
344
  unexpected_argv = []
345
  else:
346
  unexpected_argv = argv[len(_init_positional_fields_names) :]
347
 
348
- unexpected_kwargs = {k: v for k, v in kwargs.items() if k not in _init_fields_names}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
 
350
  if self.__allow_unexpected_arguments__:
351
- self._argv = unexpected_argv
352
- self._kwargs = unexpected_kwargs
 
 
353
 
354
  else:
355
  if len(unexpected_argv) > 0:
@@ -376,12 +392,12 @@ class Dataclass(metaclass=DataclassMeta):
376
  f"Required field '{field.name}' of class {field.origin_cls} not set in {self.__class__.__name__}"
377
  )
378
 
 
 
379
  for field in fields(self):
380
  if field.name in kwargs:
381
  setattr(self, field.name, kwargs[field.name])
382
  else:
383
- if field.name in ["_argv", "_kwargs"] and self.__allow_unexpected_arguments__:
384
- continue
385
  setattr(self, field.name, get_field_default(field))
386
 
387
  self.__post_init__()
@@ -390,17 +406,29 @@ class Dataclass(metaclass=DataclassMeta):
390
  def __is_dataclass__(self) -> bool:
391
  return True
392
 
 
 
 
 
 
 
393
  def __post_init__(self):
394
  """
395
  Post initialization hook.
396
  """
397
  pass
398
 
 
 
 
 
 
 
399
  def to_dict(self):
400
  """
401
  Convert to dict.
402
  """
403
- return asdict(self)
404
 
405
  def __repr__(self) -> str:
406
  """
 
235
 
236
  def _asdict_inner(obj):
237
  if is_dataclass(obj):
238
+ return obj.to_dict()
 
 
 
 
239
  elif isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
240
  return type(obj)(*[_asdict_inner(v) for v in obj])
241
  elif isinstance(obj, (list, tuple)):
 
336
  if name in kwargs:
337
  raise TypeError(f"{self.__class__.__name__} got multiple values for argument '{name}'")
338
 
339
+ expected_unexpected_argv = kwargs.pop("_argv", None)
340
+
341
  if len(argv) <= len(_init_positional_fields_names):
342
  unexpected_argv = []
343
  else:
344
  unexpected_argv = argv[len(_init_positional_fields_names) :]
345
 
346
+ if expected_unexpected_argv is not None:
347
+ assert (
348
+ len(unexpected_argv) == 0
349
+ ), f"Cannot specify both _argv and unexpected positional arguments. Got {unexpected_argv}"
350
+ unexpected_argv = tuple(expected_unexpected_argv)
351
+
352
+ expected_unexpected_kwargs = kwargs.pop("_kwargs", None)
353
+ unexpected_kwargs = {
354
+ k: v for k, v in kwargs.items() if k not in _init_fields_names and k not in ["_argv", "_kwargs"]
355
+ }
356
+
357
+ if expected_unexpected_kwargs is not None:
358
+ intersection = set(unexpected_kwargs.keys()) & set(expected_unexpected_kwargs.keys())
359
+ assert (
360
+ len(intersection) == 0
361
+ ), f"Cannot specify the same arguments in both _kwargs and in unexpected keyword arguments. Got {intersection} in both."
362
+ unexpected_kwargs = {**unexpected_kwargs, **expected_unexpected_kwargs}
363
 
364
  if self.__allow_unexpected_arguments__:
365
+ if len(unexpected_argv) > 0:
366
+ kwargs["_argv"] = unexpected_argv
367
+ if len(unexpected_kwargs) > 0:
368
+ kwargs["_kwargs"] = unexpected_kwargs
369
 
370
  else:
371
  if len(unexpected_argv) > 0:
 
392
  f"Required field '{field.name}' of class {field.origin_cls} not set in {self.__class__.__name__}"
393
  )
394
 
395
+ self.__pre_init__(**kwargs)
396
+
397
  for field in fields(self):
398
  if field.name in kwargs:
399
  setattr(self, field.name, kwargs[field.name])
400
  else:
 
 
401
  setattr(self, field.name, get_field_default(field))
402
 
403
  self.__post_init__()
 
406
  def __is_dataclass__(self) -> bool:
407
  return True
408
 
409
+ def __pre_init__(self, **kwargs):
410
+ """
411
+ Pre initialization hook.
412
+ """
413
+ pass
414
+
415
  def __post_init__(self):
416
  """
417
  Post initialization hook.
418
  """
419
  pass
420
 
421
+ def _to_raw_dict(self):
422
+ """
423
+ Convert to raw dict
424
+ """
425
+ return {field.name: getattr(self, field.name) for field in fields(self)}
426
+
427
  def to_dict(self):
428
  """
429
  Convert to dict.
430
  """
431
+ return _asdict_inner(self._to_raw_dict())
432
 
433
  def __repr__(self) -> str:
434
  """