import os import re import random import time import torch import torch.nn as nn import logging import numpy as np from os import path as osp def constant_init(module, val, bias=0): if hasattr(module, 'weight') and module.weight is not None: nn.init.constant_(module.weight, val) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) initialized_logger = {} def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): """Get the root logger. The logger will be initialized if it has not been initialized. By default a StreamHandler will be added. If `log_file` is specified, a FileHandler will also be added. Args: logger_name (str): root logger name. Default: 'basicsr'. log_file (str | None): The log filename. If specified, a FileHandler will be added to the root logger. log_level (int): The root logger level. Note that only the process of rank 0 is affected, while other processes will set the level to "Error" and be silent most of the time. Returns: logging.Logger: The root logger. """ logger = logging.getLogger(logger_name) # if the logger has been initialized, just return it if logger_name in initialized_logger: return logger format_str = '%(asctime)s %(levelname)s: %(message)s' stream_handler = logging.StreamHandler() stream_handler.setFormatter(logging.Formatter(format_str)) logger.addHandler(stream_handler) logger.propagate = False if log_file is not None: logger.setLevel(log_level) # add file handler # file_handler = logging.FileHandler(log_file, 'w') file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log file_handler.setFormatter(logging.Formatter(format_str)) file_handler.setLevel(log_level) logger.addHandler(file_handler) initialized_logger[logger_name] = True return logger IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ torch.__version__)[0][:3])] >= [1, 12, 0] def gpu_is_available(): if IS_HIGH_VERSION: if torch.backends.mps.is_available(): return True return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False def get_device(gpu_id=None): if gpu_id is None: gpu_str = '' elif isinstance(gpu_id, int): gpu_str = f':{gpu_id}' else: raise TypeError('Input should be int value.') if IS_HIGH_VERSION: if torch.backends.mps.is_available(): return torch.device('mps'+gpu_str) return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') def set_random_seed(seed): """Set random seeds.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) def get_time_str(): return time.strftime('%Y%m%d_%H%M%S', time.localtime()) def scandir(dir_path, suffix=None, recursive=False, full_path=False): """Scan a directory to find the interested files. Args: dir_path (str): Path of the directory. suffix (str | tuple(str), optional): File suffix that we are interested in. Default: None. recursive (bool, optional): If set to True, recursively scan the directory. Default: False. full_path (bool, optional): If set to True, include the dir_path. Default: False. Returns: A generator for all the interested files with relative pathes. """ if (suffix is not None) and not isinstance(suffix, (str, tuple)): raise TypeError('"suffix" must be a string or tuple of strings') root = dir_path def _scandir(dir_path, suffix, recursive): for entry in os.scandir(dir_path): if not entry.name.startswith('.') and entry.is_file(): if full_path: return_path = entry.path else: return_path = osp.relpath(entry.path, root) if suffix is None: yield return_path elif return_path.endswith(suffix): yield return_path else: if recursive: yield from _scandir(entry.path, suffix=suffix, recursive=recursive) else: continue return _scandir(dir_path, suffix=suffix, recursive=recursive)