from collections import OrderedDict

import torch
import torch.distributed as dist

from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer


def get_model_numel(model: torch.nn.Module) -> int:
    """Return the total number of parameters in a model."""
    return sum(p.numel() for p in model.parameters())


def format_numel_str(numel: int) -> str:
    """Format a parameter count as a human-readable string, e.g. 1_300_000_000 -> "1.21 B"."""
    B = 1024**3
    M = 1024**2
    K = 1024
    if numel >= B:
        return f"{numel / B:.2f} B"
    elif numel >= M:
        return f"{numel / M:.2f} M"
    elif numel >= K:
        return f"{numel / K:.2f} K"
    else:
        return f"{numel}"


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce a tensor across ranks in place, then divide by the world size to obtain the mean."""
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor
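
# Usage sketch (an assumption, not taken from the original training loop): average a
# detached scalar loss across data-parallel ranks before logging, e.g.
#   avg_loss = all_reduce_mean(loss.detach().clone())
# The reduction happens in place, so pass a clone if the original tensor is still needed.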


@torch.no_grad()
def update_ema(
    ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
) -> None:
    """
    Step the EMA model towards the current model.

    When `sharded` is True and the model is trained in low precision with a ColossalAI
    LowLevelZeroOptimizer, the fp32 master weights held by the optimizer are used for the update.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # Skip the positional embedding and any frozen parameters.
        if name == "pos_embed":
            continue
        if not param.requires_grad:
            continue
        if not sharded:
            param_data = param.data
            ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
        else:
            if param.data.dtype != torch.float32 and isinstance(optimizer, LowLevelZeroOptimizer):
                # Low-precision ZeRO training: read the fp32 master copy kept by the optimizer
                # instead of the low-precision working parameter.
                param_id = id(param)
                master_param = optimizer._param_store.working_to_master_param[param_id]
                param_data = master_param.data
            else:
                param_data = param.data
            ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)


def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag
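

# Minimal usage sketch (an assumption, not part of the original repository): exercises the
# helpers above on a toy model. The EMA step uses `sharded=False` because no ZeRO optimizer
# is involved here; in real training, `update_ema` would be called after each optimizer step.
if __name__ == "__main__":
    import copy

    model = torch.nn.Linear(256, 256)
    ema_model = copy.deepcopy(model)
    requires_grad(ema_model, False)  # EMA weights are never trained directly

    print(f"model size: {format_numel_str(get_model_numel(model))}")

    # After an optimizer step on `model`, its EMA copy would be advanced like this.
    update_ema(ema_model, model, decay=0.9999, sharded=False)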