"""Small training helpers: parameter counting, distributed mean, EMA update."""

from collections import OrderedDict

import torch
import torch.distributed as dist

try:
    from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer
except ModuleNotFoundError as exc:
    # Only tolerate ColossalAI itself being absent. If it is installed but the
    # import fails for another reason (moved module, missing dependency), fail
    # loudly: silently skipping the master-parameter path would change results.
    if exc.name != "colossalai":
        raise
    # Without ColossalAI no optimizer can be a LowLevelZeroOptimizer, so the
    # master-parameter lookup in ``update_ema`` is simply never taken.
    LowLevelZeroOptimizer = None


def get_model_numel(model: torch.nn.Module) -> int:
    """Return the total number of elements across all parameters of ``model``."""
    return sum(p.numel() for p in model.parameters())


def format_numel_str(numel: int) -> str:
    """Format an element count as a short human-readable string.

    The units are 1024-based: counts of at least ``1024**3`` are shown with a
    ``B`` suffix, at least ``1024**2`` with ``M``, at least ``1024`` with ``K``
    (two decimals each); smaller counts are printed as-is without a suffix.
    """
    B = 1024**3
    M = 1024**2
    K = 1024
    if numel >= B:
        return f"{numel / B:.2f} B"
    elif numel >= M:
        return f"{numel / M:.2f} M"
    elif numel >= K:
        return f"{numel / K:.2f} K"
    else:
        return f"{numel}"


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average ``tensor`` across all ranks, in place.

    The tensor is sum-reduced with ``dist.all_reduce`` and then divided by the
    world size. The same (mutated) tensor object is returned.
    """
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor


@torch.no_grad()
def update_ema(
    ema_model: torch.nn.Module,
    model: torch.nn.Module,
    optimizer=None,
    decay: float = 0.9999,
    sharded: bool = True,
) -> None:
    """Step the EMA model towards the current model.

    For every trainable parameter of ``model`` (except the one named
    ``"pos_embed"``, which is skipped), the same-named parameter of
    ``ema_model`` is updated in place as ``ema = decay * ema + (1 - decay) * p``.

    Args:
        ema_model: Model holding the running average; modified in place.
        model: Model whose current parameters are folded into the average.
        optimizer: Optional optimizer. When ``sharded`` is true, it is a
            ``LowLevelZeroOptimizer`` and a parameter is not float32, the value
            used is the optimizer's master copy of that parameter (looked up
            by ``id(param)`` in ``optimizer._param_store``) instead of the
            working parameter.
        decay: EMA decay factor.
        sharded: Whether the master-parameter lookup described above applies.

    Raises:
        KeyError: If ``ema_model`` has no parameter with a name that is
            updated from ``model``.
    """
    ema_params = OrderedDict(ema_model.named_parameters())

    # Loop-invariant: can the optimizer supply master copies at all?
    use_master = (
        sharded
        and LowLevelZeroOptimizer is not None
        and isinstance(optimizer, LowLevelZeroOptimizer)
    )

    for name, param in model.named_parameters():
        if name == "pos_embed" or not param.requires_grad:
            continue

        if use_master and param.data.dtype != torch.float32:
            # Low-precision working parameter: average the optimizer's master
            # copy instead (presumably the full-precision version).
            source = optimizer._param_store.working_to_master_param[id(param)].data
        else:
            source = param.data

        ema_params[name].mul_(decay).add_(source, alpha=1 - decay)


def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
    """Set the ``requires_grad`` flag of every parameter in ``model`` to ``flag``."""
    for p in model.parameters():
        p.requires_grad = flag