# Source: DDT/src/callbacks/simple_ema.py
# wangshuai6 - init space - 9e426da
from typing import Any, Dict
import torch
import torch.nn as nn
import threading
import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from src.utils.copy import swap_tensors
class SimpleEMA(Callback):
    """Lightning callback maintaining an exponential moving average of a network.

    After training batches the parameters of ``ema_net`` are updated in place as
    ``ema = decay * ema + (1 - decay) * net``.  Around validation and prediction
    epochs the two parameter sets are exchanged via ``swap_tensors`` (unless
    ``eval_original_model`` is set) and exchanged back afterwards.

    Args:
        net: Module whose parameters are being trained.
        ema_net: Module whose parameters hold the running average.  Its
            parameters are paired positionally with those of ``net``.
        decay: Weight given to the previous average on every update.
        every_n_steps: The average is updated when
            ``trainer.global_step % every_n_steps == 0``.
        eval_original_model: When true, no swap is performed for validation
            or prediction.
    """

    def __init__(
        self,
        net: nn.Module,
        ema_net: nn.Module,
        decay: float = 0.9999,
        every_n_steps: int = 1,
        eval_original_model: bool = False,
    ):
        super().__init__()
        self.decay = decay
        self.every_n_steps = every_n_steps
        self.eval_original_model = eval_original_model
        # A dedicated stream lets the EMA kernels be queued apart from the
        # training stream.  Creating one without CUDA raises, so fall back to
        # ``None`` and update synchronously in that case.
        self._stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        # Snapshot the parameter lists once; they are paired by position.
        self.net_params = list(net.parameters())
        self.ema_params = list(ema_net.parameters())

    def swap_model(self) -> None:
        """Call ``swap_tensors`` on every (EMA parameter, net parameter) pair."""
        if self._stream is not None:
            # An EMA update may still be queued on the side stream; make the
            # current stream wait for it so the swap (and whatever runs after
            # it) does not observe half-updated averages.
            torch.cuda.current_stream().wait_stream(self._stream)
        for ema_p, p in zip(self.ema_params, self.net_params):
            # presumably exchanges the contents of the two tensors — see
            # src.utils.copy.swap_tensors
            swap_tensors(ema_p, p)

    def _ema_update(self) -> None:
        """Apply one in-place EMA update to ``self.ema_params``."""
        with torch.no_grad():
            torch._foreach_mul_(self.ema_params, self.decay)
            torch._foreach_add_(
                self.ema_params, self.net_params, alpha=(1.0 - self.decay),
            )

    def ema_step(self) -> None:
        """Run one EMA update, on the side stream when one is available."""
        if self._stream is None:
            # No CUDA stream: perform the update directly instead of silently
            # skipping it.
            self._ema_update()
            return
        # The side stream must first see everything already queued on the
        # current stream (e.g. the optimizer step that produced the weights).
        self._stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._stream):
            self._ema_update()

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Update the average when the global step is a multiple of ``every_n_steps``."""
        if trainer.global_step % self.every_n_steps == 0:
            self.ema_step()

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap parameters before validation unless the original model is evaluated."""
        if not self.eval_original_model:
            self.swap_model()

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap parameters back after validation."""
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap parameters before prediction unless the original model is evaluated."""
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap parameters back after prediction."""
        if not self.eval_original_model:
            self.swap_model()

    def state_dict(self) -> Dict[str, Any]:
        """Return the callback's hyper-parameters (no tensors are included)."""
        return {
            "decay": self.decay,
            "every_n_steps": self.every_n_steps,
            "eval_original_model": self.eval_original_model,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore the hyper-parameters written by :meth:`state_dict`."""
        self.decay = state_dict["decay"]
        self.every_n_steps = state_dict["every_n_steps"]
        self.eval_original_model = state_dict["eval_original_model"]