File size: 1,473 Bytes
fa6856c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
from typing import List

from trlx.data.configs import TRLConfig


def _get_config_dirs(dir: str, config_dir_name: str = "configs") -> List[str]:
    """Collects every directory below `dir` whose name equals `config_dir_name`.

    The tree rooted at `dir` is traversed with `os.walk`, and each matching
    directory is returned as the walk root joined with the directory name.
    """
    return [
        os.path.join(parent, name)
        for parent, subdirs, _files in os.walk(dir)
        for name in subdirs
        if name == config_dir_name
    ]


def _get_yaml_filepaths(dir: str) -> List[str]:
    """Lists the entries directly inside `dir` whose names end with `.yml`.

    Each returned path is the entry name joined onto `dir`; sub-directories
    are not searched.
    """
    yml_names = filter(lambda name: name.endswith(".yml"), os.listdir(dir))
    return [os.path.join(dir, name) for name in yml_names]


def test_repo_trl_configs():
    """Tests to ensure all default configs in the repository are valid.

    Gathers every `.yml` file from the top-level `configs` directory and from each
    `configs` directory found under `examples` (both relative to the current working
    directory), then checks that each file loads through `TRLConfig.load_yaml` and
    that the loaded config's `train.entity_name` is `None`.

    Raises:
        AssertionError: If a config file cannot be loaded or sets an entity name.
    """
    config_dirs = ["configs", *_get_config_dirs("examples")]
    # Flatten the per-directory lists into a single list of filepaths.
    config_files = [file for config_dir in config_dirs for file in _get_yaml_filepaths(config_dir)]
    for file in config_files:
        assert os.path.isfile(file), f"Config file {file} does not exist."
        assert file.endswith(".yml"), f"Config file {file} is not a yaml file."
        # Keep the `try` limited to loading: if the entity-name assertion lived inside
        # it, its AssertionError would be caught and misreported as a load failure.
        try:
            config = TRLConfig.load_yaml(file)
        except Exception as e:
            # Chain the original exception so the underlying traceback stays visible.
            raise AssertionError(f"Failed to load config file `{file}` with error `{e}`") from e
        assert (
            config.train.entity_name is None
        ), f"Unexpected entity name in config file `{file}`. Remove before pushing to repo."