|
import dataclasses
|
|
import re
|
|
import copy
|
|
import yaml
|
|
from pathlib import Path
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union, Dict
|
|
|
|
from transformers.hf_argparser import DataClass, HfArgumentParser as OriginalHfArgumentParser
|
|
|
|
# Placeholder type used in annotations for "an instance of some dataclass".
# NOTE: this assignment rebinds the ``DataClass`` name imported from
# ``transformers.hf_argparser`` above, so the local NewType is the one in effect.
DataClass = NewType("DataClass", Any)
|
|
# Companion placeholder type, presumably for "a dataclass type" rather than an
# instance; it is not referenced in the visible part of this file.
DataClassType = NewType("DataClassType", Any)
|
|
|
|
def lambda_field(default, **kwargs):
    """Build a dataclass field whose default is a fresh copy of ``default``.

    Dataclasses reject mutable defaults such as lists or dicts; wrapping the
    value in a ``default_factory`` that shallow-copies it gives every instance
    its own object.

    Args:
        default: Value to shallow-copy (``copy.copy``) for each new instance.
        **kwargs: Extra keyword arguments forwarded to ``dataclasses.field``
            (for example ``metadata``, ``init`` or ``repr``). They were
            previously accepted but silently discarded.

    Returns:
        The ``dataclasses.Field`` object produced by ``dataclasses.field``.
    """
    return field(default_factory=lambda: copy.copy(default), **kwargs)
|
|
|
|
class HfArgumentParser(OriginalHfArgumentParser):
    """Argument parser that can populate its dataclasses from a sectioned YAML file."""

    # Float pattern registered on the YAML loader. Besides the dotted forms it
    # accepts exponent-only scalars (second alternative, e.g. ``1e-5``), so such
    # values are resolved with the YAML float tag.
    _YAML_FLOAT_PATTERN = re.compile(
        r"""^(?:
         [-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
        |[-+]?\.(?:inf|Inf|INF)
        |\.(?:nan|NaN|NAN))$""",
        re.X,
    )

    @classmethod
    def _yaml_loader(cls):
        """Return a ``yaml.SafeLoader`` subclass carrying the extra float resolver."""

        class _FloatAwareSafeLoader(yaml.SafeLoader):
            """Private loader so the resolver is not added to ``yaml.SafeLoader`` itself."""

        # Registering on the subclass keeps the shared SafeLoader untouched and
        # avoids stacking a duplicate resolver on it at every call.
        _FloatAwareSafeLoader.add_implicit_resolver(
            "tag:yaml.org,2002:float",
            cls._YAML_FLOAT_PATTERN,
            list("-+0123456789."),
        )
        return _FloatAwareSafeLoader

    def parse_yaml_file(self, yaml_file: str) -> Tuple[DataClass, ...]:
        """
        Parse a YAML file and return a tuple of dataclass instances.

        The file is expected to hold one top-level mapping per dataclass type,
        keyed by the name of that type's root class (the entry just before
        ``object`` in its MRO). Keys inside a section that do not correspond to
        an ``init`` field of the dataclass are ignored.

        Args:
            yaml_file (str): Path to the YAML file.

        Returns:
            Tuple[DataClass, ...]: One instance per entry of
            ``self.dataclass_types``, in the same order.

        Raises:
            KeyError: If the section for one of the dataclass types is missing
                (including the case of an empty YAML document).
        """
        # An empty document loads as None; treat it as "no sections" so the
        # missing-section error below is raised instead of an opaque TypeError.
        data = yaml.load(Path(yaml_file).read_text(), Loader=self._yaml_loader()) or {}

        outputs = []
        for dtype in self.dataclass_types:
            init_names = {f.name for f in dataclasses.fields(dtype) if f.init}
            # Sections are named after the root class, so a subclass of an
            # arguments dataclass still reads its parent's section.
            section_name = dtype.__mro__[-2].__name__
            if section_name not in data:
                raise KeyError(
                    f"section {section_name!r} not found in YAML file {yaml_file}"
                )
            inputs = {k: v for k, v in data[section_name].items() if k in init_names}
            outputs.append(dtype(**inputs))

        return tuple(outputs)
|
|
|