import logging import os import time import git import omegaconf logger = logging.getLogger(__name__) def load_config(config_path, command_line_args=None): """Loads the config file using OmegaConf, performing merges with base configs and the command line arguments""" logger.info(f"Loading from: {config_path}") # Load the config using OmegaConf config = omegaconf.OmegaConf.load(config_path) # Load all the base configs to include, and merge them with the current config, giving precedence to the current config if hasattr(config, "include"): base_config_paths = [os.path.join(os.path.dirname(config_path), include_path) for include_path in config.include] base_configs = [load_config(base_config_path) for base_config_path in base_config_paths] config = omegaconf.OmegaConf.merge(*base_configs, config) # Load the command line arguments, and merge them with the current config, giving precedence to the command line if command_line_args is not None: command_line_config = omegaconf.OmegaConf.from_dotlist(command_line_args) config = omegaconf.OmegaConf.merge(config, command_line_config) return config def save_git_commit_info(save_path): """Use gitpython to save info about the current git commit to a file""" repo = git.Repo(search_parent_directories=True) head_commit = repo.head.commit git_commit_info = { "hexsha": head_commit.hexsha, "authored": { "author": head_commit.author.name, "authored_time": head_commit.authored_date, }, "committed": { "commit": head_commit.committer.name, "committed_time": head_commit.committed_date, }, "message": head_commit.message.strip(), } git_commit_info = omegaconf.OmegaConf.create(git_commit_info) omegaconf.OmegaConf.save(git_commit_info, save_path) return git_commit_info def create_workspace(config): """Create a results folder in the target directory""" # Treat the name as a time.strftime format string (so that every experiment is named after when it was run) config.name = time.strftime(config.name, time.localtime()) # Create the results directory os.makedirs(config.save_dir) # Save the config to the results directory omegaconf.OmegaConf.save(config, os.path.join(config.save_dir, "config.yaml")) save_git_commit_info(os.path.join(config.save_dir, "git.yaml")) # Set up the print loggers by removing all handlers associated with the root logger object, # then setting up the logger to print messages *and* save them to a file for handler in logging.root.handlers: logging.root.removeHandler(handler) logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[ logging.FileHandler(os.path.join(config.save_dir, "output.log")), logging.StreamHandler(), ], ) return config