from typing import Any, Union

import ignite.distributed as idist
import torch
from ignite.engine import DeterministicEngine, Engine, Events
from torch.cuda.amp import autocast
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler, Sampler

def setup_trainer(
    config: Any,
    model: Module,
    optimizer: Optimizer,
    loss_fn: Module,
    device: Union[str, torch.device],
    train_sampler: Sampler,
) -> Union[Engine, DeterministicEngine]:
    def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
        if config.overfit:
            # keep batch-norm statistics frozen while overfitting a small subset
            model.eval()
        else:
            model.train()

        samples = batch[0].to(device, non_blocking=True)
        targets = batch[1].to(device, non_blocking=True)
        attack_targets = batch[2].to(device, non_blocking=True)
        sample_ids = batch[3].to(device, non_blocking=True)

        with autocast(config.use_amp):
            outputs = model(samples, attack_targets)
            loss = loss_fn(outputs, attack_targets, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss = loss.item()
        engine.state.metrics = {
            "epoch": engine.state.epoch,
            "train_loss": train_loss,
        }
        return {"train_loss": train_loss}

    trainer = Engine(train_function)

    # set epoch for the distributed sampler
    @trainer.on(Events.EPOCH_STARTED)
    def set_epoch():
        if idist.get_world_size() > 1 and isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(trainer.state.epoch - 1)

    return trainer

def setup_evaluator(
    config: Any,
    model: Module,
    device: Union[str, torch.device],
) -> Engine:
    def eval_function(engine: Engine, batch: Any):
        model.eval()

        samples, gt_labels, attack_targets, sample_ids = batch
        samples = samples.to(device, non_blocking=True)
        gt_labels = gt_labels.to(device, non_blocking=True)
        attack_targets = attack_targets.to(device, non_blocking=True)
        sample_ids = sample_ids.to(device, non_blocking=True)

        with autocast(config.use_amp):
            # in evaluation the model returns both its predictions and the
            # generated perturbations
            outputs, perturbations = model(samples, attack_targets, gt_labels)

        return outputs, attack_targets, {
            "gt_targets": gt_labels,
            "perturbations": perturbations,
        }

    return Engine(eval_function)
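

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): wires setup_trainer and
# setup_evaluator together with toy stand-ins so the control flow can be
# smoke-tested on CPU. `_ToyAttackModel`, `_toy_loss` and the SimpleNamespace
# config are hypothetical placeholders, not the project's real model, loss, or
# configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    class _ToyAttackModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 4)

        def forward(self, samples, attack_targets, gt_labels=None):
            outputs = self.fc(samples)
            if gt_labels is None:
                # training call: predictions only
                return outputs
            # evaluation call: also return (dummy) perturbations
            return outputs, torch.zeros_like(samples)

    def _toy_loss(outputs, attack_targets, targets):
        return nn.functional.cross_entropy(outputs, attack_targets)

    n = 16
    dataset = TensorDataset(
        torch.randn(n, 8),          # samples
        torch.randint(0, 4, (n,)),  # ground-truth labels
        torch.randint(0, 4, (n,)),  # attack targets
        torch.arange(n),            # sample ids
    )
    loader = DataLoader(dataset, batch_size=4)

    config = SimpleNamespace(overfit=False, use_amp=False, max_epochs=1)
    device = "cpu"
    model = _ToyAttackModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    trainer = setup_trainer(config, model, optimizer, _toy_loss, device, loader.sampler)
    evaluator = setup_evaluator(config, model, device)

    @trainer.on(Events.EPOCH_COMPLETED)
    def _run_validation():
        evaluator.run(loader)

    trainer.run(loader, max_epochs=config.max_epochs)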