File size: 31,355 Bytes
9d0d223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
import typing as tp

import flashy
import omegaconf
import torch
from torch import nn

from .. import optim
from ..optim import fsdp
from ..utils import checkpoint
from ..utils.autocast import TorchAutocast
from ..utils.best_state import BestStateDictManager
from ..utils.deadlock import DeadlockDetect
from ..utils.profiler import Profiler
from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng


class StandardSolver(ABC, flashy.BaseSolver):
    """Standard solver for AudioCraft.

    The standard solver implements a base training loop with the following stages:
    train, valid, evaluate and generate that are expected to be all defined for
    solvers in AudioCraft. It also provides a nice default management of Dora history replay,
    checkpoint management across epoch, and logging configuration.

    AudioCraft solvers must inherit from the StandardSolver and define the methods
    associated to each stage as well as the show, build_model and build_dataloaders methods.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        """Initialize the solver from its configuration.

        In order, this sets up the logging backends, the best-state manager, the dataloaders
        (through `build_dataloaders`), the model (through `build_model`), the profiler, the EMA
        and the deadlock detector, and finally logs basic statistics on the model size.

        Args:
            cfg (omegaconf.DictConfig): Solver configuration. The entries read here are `device`,
                `logging`, `tensorboard`, `wandb`, `fsdp`, `autocast`, `autocast_dtype`,
                `execute_only`, `optim`, `profiler` and `deadlock`.
        """
        super().__init__()
        self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}")
        self.logger.info(f"All XP logs are stored in {self.xp.folder}")
        self.cfg = cfg
        self.device = cfg.device
        # only declared here: the attribute is read right after build_model() below,
        # so build_model() is expected to assign it.
        self.model: nn.Module
        # keys kept from a checkpoint that does not belong to the current XP (see load_checkpoints)
        self._continue_best_source_keys = ['best_state', 'fsdp_best_state']
        # filled by wrap_with_fsdp()
        self._fsdp_modules: tp.List[fsdp.FSDP] = []
        # filled by register_ema(), consumed by initialize_ema()
        self._ema_sources: nn.ModuleDict = nn.ModuleDict()
        self.ema: tp.Optional[optim.ModuleDictEMA] = None
        self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
        self._log_updates = self.cfg.logging.get('log_updates', 10)
        if self.cfg.logging.log_tensorboard:
            self.init_tensorboard(**self.cfg.get('tensorboard'))
        # NOTE(review): the trailing `and self` looks like a leftover; it only matters if the
        # solver defines __bool__/__len__ - confirm before removing.
        if self.cfg.logging.log_wandb and self:
            self.init_wandb(**self.cfg.get('wandb'))
        # keep a copy of the best performing state for stateful objects
        # used for evaluation and generation stages
        dtype_best: tp.Optional[torch.dtype] = None
        if self.cfg.fsdp.use:
            # the FSDP parameter dtype takes precedence over the autocast dtype
            dtype_best = getattr(torch, self.cfg.fsdp.param_dtype)  # type: ignore
            assert isinstance(dtype_best, torch.dtype)
        elif self.cfg.autocast:
            dtype_best = getattr(torch, self.cfg.autocast_dtype)  # type: ignore
            assert isinstance(dtype_best, torch.dtype)
        self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
        # Hacky support for keeping a copy of the full best state in rank0.
        self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
        self.register_stateful('best_state', 'fsdp_best_state')  # register best_state object to keep it in state_dict
        self._new_best_state: bool = False  # should save a new checkpoint
        # instantiate datasets and appropriate number of updates per epoch
        self.build_dataloaders()
        if self.cfg.execute_only is None:
            assert 'train' in self.dataloaders, "The train dataset split must be provided."
            assert 'valid' in self.dataloaders, "The valid dataset split must be provided."
        self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0
        # an explicit `optim.updates_per_epoch` overrides the length of the train dataloader
        if self.cfg.optim.updates_per_epoch:
            self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
        self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
        # instantiate model & exponential moving average on the model
        self.build_model()
        self.logger.info("Model hash: %s", model_hash(self.model))
        assert 'model' in self.stateful.sources, \
            "Please register the model to stateful with self.register_stateful('model') in build_model."
        self.profiler = Profiler(self.model, **self.cfg.profiler)
        self.initialize_ema()
        self.register_stateful('ema')
        # NOTE(review): 'ema' is registered on the line above, so the reference to build_model
        # in the message below may be outdated - confirm.
        assert self.ema is None or 'ema' in self.stateful.sources, \
            "Please register the ema to stateful with self.register_stateful('ema') in build_model."
        self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
        # basic statistics on the trained model
        # (only parameters with requires_grad are counted, expressed in millions)
        model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6
        # one copy of grad, one copy of momentum, one copy of denominator and model weights.
        # and 4 bytes for each float!
        mem_usage = model_size * 4 * 4 / 1000
        self.logger.info("Model size: %.2f M params", model_size)
        self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage)

    @property
    def autocast(self):
        """Autocast context built from the solver configuration (it may be disabled).

        Reads `self.autocast_dtype`, which is not assigned anywhere in this class.
        """
        autocast_kwargs = dict(
            enabled=self.cfg.autocast,
            device_type=self.device,
            dtype=self.autocast_dtype,
        )
        return TorchAutocast(**autocast_kwargs)

    def _get_state_source(self, name) -> flashy.state.StateDictSource:
        # Internal utility to get a state source from the solver
        return self.stateful.sources[name]

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        """Metric name used to identify the best state. This metric should be stored in the metrics
        used on the stage for best state identification (most likely, `valid`).

        This base implementation returns None, in which case `update_best_state_from_stage`
        treats the latest state as the best one. Solvers that want metric-based selection
        override this property; the comparison there keeps the lowest value.
        """
        return None

    def register_best_state(self, *args: str):
        """Register state sources in `BestStateDictManager` to keep their best states along with their
        latest states. The best state will be used at evaluation stages instead of the latest states.

        Shortcut around `BestStateDictManager.register` method. You can pass any number of
        attribute, included nested attributes and those will be included into the checkpoints
        and automatically restored when `BaseSolver.restore` is called.
        """
        for name in args:
            state_source = self._get_state_source(name)
            assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
            self.best_state.register(name, state_source)

    def register_ema(self, *args: str):
        """Register state sources for exponential moving average.

        The registered sources are used to instantiate a ModuleDictEMA instance.
        The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called
        and swapped with the original state sources with self.swap_ema_state() method.

        Usage:
            self.register_ema('model')
        """
        assert self.ema is None, "Cannot register state source to already instantiated EMA."
        for name in args:
            self._ema_sources[name] = getattr(self, name)

    def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
        """Wrap `model` through `fsdp.wrap_with_fsdp` using the solver FSDP configuration.

        When the result is an `fsdp.FSDP` instance it is also recorded in `self._fsdp_modules`.
        Extra positional and keyword arguments are forwarded untouched.

        Returns:
            The module returned by `fsdp.wrap_with_fsdp`.
        """
        wrapped = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
        if not isinstance(wrapped, fsdp.FSDP):
            return wrapped
        self._fsdp_modules.append(wrapped)
        return wrapped

    def update_best_state_from_stage(self, stage_name: str = 'valid'):
        """Update latest best state based on pending metrics of a given stage. This method relies
        on the `BestStateDictManager.update` method to update the best state_dict with latest weights
        if the registered states happen to match to the best performing setup.
        """
        if self.best_metric_name is None:
            # when no best metric is defined, the last state is always the best
            self._new_best_state = True
            self.logger.info("Updating best state with current state.")
        else:
            assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
            assert self.best_metric_name in self._pending_metrics[stage_name], \
                f"Best metric not found in {stage_name} metrics. Cannot register best state"
            current_score = self._pending_metrics[stage_name][self.best_metric_name]
            all_best_metric_scores = [
                past_metrics[stage_name][self.best_metric_name]
                for past_metrics in self.history
            ]
            all_best_metric_scores.append(current_score)
            best_score = min(all_best_metric_scores)
            self._new_best_state = current_score == best_score
            if self._new_best_state:
                old_best = min(all_best_metric_scores[:-1] + [float('inf')])
                self.logger.info(
                    f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")

        if self._new_best_state:
            if self.cfg.fsdp.use:
                # this will give an empty state dict on all ranks but the rank 0
                # which will have a copy in memory of the full model.
                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
                    for name in self.best_state.states.keys():
                        state_source = self._get_state_source(name)
                        self.best_state.update(name, state_source)
                    # we save to a different dict.
                    self.fsdp_best_state.update(self.best_state.state_dict())
                # We cannot efficiently load fsdp_best_state when using FSDP,
                # so we have do do a second pass, with the local shards.
            for name in self.best_state.states.keys():
                state_source = self._get_state_source(name)
                self.best_state.update(name, state_source)

    def _load_new_state_dict(self, state_dict: dict) -> dict:
        """Load each entry of `state_dict` into the source of the same name.

        Returns:
            dict: Copies (made with `copy_state`) of the states the sources held before loading,
                indexed by source name.
        """
        previous_states = {}
        for source_name, incoming_state in state_dict.items():
            source = self._get_state_source(source_name)
            # take the snapshot first: the source is overwritten on the next line
            snapshot = copy_state(source.state_dict())
            previous_states[source_name] = snapshot
            source.load_state_dict(incoming_state)
        return previous_states

    @contextmanager
    def swap_best_state(self):
        self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}")
        old_states = self._load_new_state_dict(self.best_state.state_dict())
        try:
            yield
        finally:
            self.logger.debug("Swapping back from best to original state")
            for name, old_state in old_states.items():
                state_source = self._get_state_source(name)
                state_source.load_state_dict(old_state)

    @contextmanager
    def swap_ema_state(self):
        if self.ema is None:
            yield
        else:
            ema_state_dict = self.ema.state_dict()['state']
            self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}")
            old_states = self._load_new_state_dict(ema_state_dict)
            try:
                yield
            finally:
                self.logger.debug("Swapping back from EMA state to original state")
                for name, old_state in old_states.items():
                    state_source = self._get_state_source(name)
                    state_source.load_state_dict(old_state)

    @property
    def is_training(self):
        return self.current_stage == 'train'

    def log_model_summary(self, model: nn.Module):
        """Log model summary, architecture and size of the model."""
        self.logger.info(model)
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
        self.logger.info("Size: %.1f MB", mb)

    @abstractmethod
    def build_model(self):
        """Method to implement to initialize model.

        Called from `__init__` once the dataloaders are built. Right after this call the
        constructor reads `self.model` and asserts that a `'model'` source is registered, so
        implementations are expected to assign `self.model` and to call
        `self.register_stateful('model')`.
        """
        ...

    def initialize_ema(self):
        """Create `self.ema` from the sources registered with `register_ema`.

        The EMA is built by `builders.get_ema` from `cfg.optim.ema`; that builder may return
        None, in which case the solver runs without EMA. When an EMA is created,
        `cfg.optim.ema.updates` must be strictly positive.
        """
        from .builders import get_ema
        self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
        if self.ema is None:
            self.logger.info('No EMA on the model.')
            return
        assert self.cfg.optim.ema.updates > 0
        decay_part = f'Initializing EMA on the model with decay = {self.ema.decay}'
        updates_part = f' every {self.cfg.optim.ema.updates} updates'
        self.logger.info(decay_part + updates_part)

    @abstractmethod
    def build_dataloaders(self):
        """Method to implement to initialize dataloaders.

        Called from `__init__`, which then looks the splits up by name in `self.dataloaders`:
        when `cfg.execute_only` is None, both the `'train'` and `'valid'` entries must exist.
        """
        ...

    @abstractmethod
    def show(self):
        """Method to log any information without running the job.

        Abstract: every concrete solver has to provide it.
        """
        ...

    @property
    def log_updates(self):
        """Convenient access to the `log_updates` value read from `cfg.logging` (10 when unset).

        It is passed as `updates` to `self.log_progress` in `common_train_valid`.
        """
        return self._log_updates

    def checkpoint_path(self, **kwargs):
        """Path, inside `self.folder`, of the checkpoint named by `checkpoint.checkpoint_name`.

        `use_fsdp` defaults to the solver FSDP setting unless given explicitly; all keyword
        arguments are forwarded to `checkpoint.checkpoint_name`.
        """
        name_kwargs = {'use_fsdp': self.cfg.fsdp.use, **kwargs}
        return self.folder / checkpoint.checkpoint_name(**name_kwargs)

    def epoch_checkpoint_path(self, epoch: int, **kwargs):
        """Path, inside `self.folder`, of the checkpoint kept for a given epoch.

        The epoch is passed as a string to `checkpoint.checkpoint_name`; `use_fsdp` defaults to
        the solver FSDP setting unless given explicitly.
        """
        name_kwargs = {'use_fsdp': self.cfg.fsdp.use, **kwargs}
        return self.folder / checkpoint.checkpoint_name(str(epoch), **name_kwargs)

    def checkpoint_path_with_name(self, name: str, **kwargs):
        """Path, inside `self.folder`, of the checkpoint identified by `name`.

        `use_fsdp` defaults to the solver FSDP setting unless given explicitly.
        """
        name_kwargs = {'use_fsdp': self.cfg.fsdp.use, **kwargs}
        return self.folder / checkpoint.checkpoint_name(name=name, **name_kwargs)

    def save_checkpoints(self):
        """Write the solver state to disk: the latest checkpoint and, periodically, an epoch copy.

        Only rank zero saves, unless FSDP is used, in which case every rank goes through the
        saving code with `is_sharded` set. The epoch copy can be restricted to the sources listed
        in `cfg.checkpoint.keep_every_states`.
        """
        is_sharded = self.cfg.fsdp.use
        if not (flashy.distrib.is_rank_zero() or is_sharded):
            return
        self.logger.info("Model hash: %s", model_hash(self.model))
        state = self.state_dict()
        # pushing metrics will increase the epoch in Flashy, so we do -1 here
        epoch = self.epoch - 1
        ckpt_cfg = self.cfg.checkpoint

        # every `save_every` epochs, keep a (possibly reduced) copy of the state for that epoch
        if ckpt_cfg.save_every and epoch % ckpt_cfg.save_every == 0:
            kept_names = ckpt_cfg.keep_every_states
            if kept_names is not None and len(kept_names) > 0:
                epoch_state = {
                    source_name: source_state for source_name, source_state in state.items()
                    if source_name in ckpt_cfg.keep_every_states
                }
            else:
                epoch_state = state
            checkpoint.save_checkpoint(epoch_state, self.epoch_checkpoint_path(epoch), is_sharded)

        # the full state is saved as the latest checkpoint
        if ckpt_cfg.save_last:
            checkpoint.save_checkpoint(state, self.checkpoint_path(), is_sharded)

        # flush any stale checkpoint to reduce disk footprint
        checkpoint.flush_stale_checkpoints(self.checkpoint_path())

    def load_from_pretrained(self, name: str) -> dict:
        raise NotImplementedError("Solver does not provide a way to load pretrained models.")

    def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
        """Load last checkpoint or the one specified in continue_from.

        The checkpoint of the current XP takes precedence when it exists. Otherwise
        `cfg.continue_from` is used, either as a checkpoint to resolve or, when it starts with
        `//pretrained/`, as a name handed to `load_from_pretrained`. When the state does not come
        from the current XP, only the best state entries are kept and `load_best` is forced to True.

        Args:
            load_best (bool): Whether to load from best state dict or not.
                Best state dict is always used when not loading the current xp.
            ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
                Only applied when loading the best state. The list is only read, never modified.
        Returns:
            state (dict, optional): The loaded state dictionary, None when nothing was loaded.
        Raises:
            RuntimeError: If `cfg.continue_from` cannot be resolved to a checkpoint, or if the epoch
                differs from the value returned by `flashy.distrib.average_metrics` after loading.
        """
        # load checkpoints from xp folder or cfg.continue_from
        is_sharded = self.cfg.fsdp.use
        load_from_path: tp.Optional[Path] = None
        checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None

        if load_best:
            self.logger.info("Trying to load state_dict from best state.")

        state: tp.Optional[dict] = None
        rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
        current_checkpoint_path = self.checkpoint_path()
        _pretrained_prefix = '//pretrained/'
        continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
        # existence is tested on the rank 0 path, the path actually loaded is the current one
        if rank0_checkpoint_path.exists():
            self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
            load_from_path = current_checkpoint_path
            checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
            checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
        elif self.cfg.continue_from and not continue_pretrained:
            self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
            # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
            load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
            if load_from_path is None:
                self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
                raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
            checkpoint_source = checkpoint.CheckpointSource.OTHER

        if load_from_path is not None:
            state = checkpoint.load_checkpoint(load_from_path, is_sharded)
        elif continue_pretrained:
            self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
            state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
            checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
            load_best = True

        # checkpoints are not from the current xp, we only retrieve the best state
        if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
            assert state is not None
            self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
            load_best = True
            state = {key: state[key] for key in self._continue_best_source_keys if key in state}
            # loaded checkpoints are FSDP checkpoints: we're reading the best state
            # from FSDP and we drop the regular best_state
            if 'fsdp_best_state' in state and state['fsdp_best_state']:
                state.pop('best_state', None)
                self.logger.info("... Loaded checkpoint has FSDP best state")
            # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
            # then we're initializing FSDP best state with the regular best state
            elif self.cfg.fsdp.use:
                if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
                    # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
                    state['fsdp_best_state'] = state.pop('best_state')
                    self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")

        if state is not None:
            if load_best:
                self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
                for key in set(ignore_state_keys):
                    if key in state:
                        state.pop(key)
                has_best_state = 'best_state' in state or 'fsdp_best_state' in state
                # the message is a single string: the two literals are concatenated, not a tuple
                assert has_best_state, ("Trying to load best state but neither 'best_state'"
                                        " or 'fsdp_best_state' found in checkpoints.")
            self.load_state_dict(state)

        # for FSDP, let's make extra sure nothing bad happened with out of sync
        # checkpoints across workers.
        epoch = float(self.epoch)
        avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
        if avg_epoch != epoch:
            raise RuntimeError(
                f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
                f"but average of epochs is {avg_epoch}, at least one gpu must have a "
                "different epoch number.")

        # on load_best, properly reinitialize state_dict, best states and ema
        # otherwise we load from the current xp and don't alter anything
        if load_best:
            self.logger.info("Loading state_dict from best state.")
            if not self.cfg.fsdp.use and self.fsdp_best_state:
                # loading from an FSDP checkpoint but with FSDP deactivated
                self.logger.info("... Loading from FSDP best state dict.")
                self.best_state.load_state_dict(self.fsdp_best_state)

            # if load_best, we permanently override the regular state_dict with the best state
            if self.cfg.fsdp.use:
                self.logger.info("FSDP is used, loading from FSDP best state.")
                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
                    # this might be really fragile but okay for now.
                    self.load_state_dict(self.fsdp_best_state)
            else:
                # we permanently swap the stateful objects to their best state
                self._load_new_state_dict(self.best_state.state_dict())

            # the EMA modules should also be instantiated with best state.
            # the easiest way to do so is to reinitialize a new EMA with best state loaded.
            if self.ema is not None:
                self.logger.info("Re-initializing EMA from best state")
                self.initialize_ema()

            if self.cfg.fsdp.use:
                self.logger.info("Re-initializing best state after using FSDP best state.")
                for name in self.best_state.states.keys():
                    state_source = self._get_state_source(name)
                    self.best_state.update(name, state_source)

        return state

    def restore(self, load_best: bool = False, replay_metrics: bool = False,
                ignore_state_keys: tp.List[str] = []) -> bool:
        """Reload the solver state through `load_checkpoints` and optionally replay past metrics.

        Args:
            load_best (bool): if `True`, load the best state from the checkpoint.
            replay_metrics (bool): if `True`, logs all the metrics from past epochs.
            ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
        Returns:
            bool: True when `load_checkpoints` returned a state, False otherwise.
        """
        self.logger.info("Restoring weights and history.")
        loaded_state = self.load_checkpoints(load_best, ignore_state_keys)

        self.logger.info("Model hash: %s", model_hash(self.model))

        if replay_metrics and len(self.history) > 0:
            self.logger.info("Replaying past metrics...")
            # epochs are reported starting from 1
            for epoch_number, stages in enumerate(self.history, start=1):
                for stage_name, stage_metrics in stages.items():
                    # The summaries go straight to the result logger: replayed metrics
                    # must not end up in the pending metrics.
                    self.result_logger._log_summary(stage_name, stage_metrics, step=epoch_number,
                                                    step_name='epoch',
                                                    formatter=self.get_formatter(stage_name))
        return loaded_state is not None

    def commit(self, save_checkpoints: bool = True):
        """Close the current epoch: store its metrics, optionally save checkpoints, start the next one.

        Args:
            save_checkpoints (bool): Whether to call `save_checkpoints` for this epoch.
        """
        # commit is overridden to introduce more complex checkpoint saving behaviors
        pending = self._pending_metrics
        self.history.append(pending)  # This will increase self.epoch
        if save_checkpoints:
            self.save_checkpoints()
        self._start_epoch()
        if not flashy.distrib.is_rank_zero():
            return
        # only rank zero pushes the history to the XP link
        self.xp.link.update_history(self.history)

    def run_epoch(self):
        """Run every stage of one epoch: train, valid, then evaluate and generate when due.

        Validation runs with the EMA states swapped in (when an EMA exists) and the best state is
        updated from its metrics. Evaluation and generation run with the best state swapped in.
        Everything after training runs under `torch.no_grad()`. The metrics of each stage are
        stored in `_pending_metrics` and committed by the solver afterwards.

        Children solvers can extend this method with custom behavior, e.g.:

            def run_epoch(self):
                ... # custom code
                super().run_epoch()
                ... # custom code
        """
        self.run_stage('train', self.train)
        with torch.no_grad():
            with self.swap_ema_state():
                self.run_stage('valid', self.valid)
                # the best state is updated with EMA states if available
                self.update_best_state_from_stage('valid')
            with self.swap_best_state():
                run_evaluate = self.should_run_stage('evaluate')
                if run_evaluate:
                    self.run_stage('evaluate', self.evaluate)
                run_generate = self.should_run_stage('generate')
                if run_generate:
                    self.run_stage('generate', with_rank_rng()(self.generate))

    def run(self):
        """Training loop: restore the XP, then run and commit epochs until training is over."""
        assert len(self.state_dict()) > 0
        # load checkpoint and replay history
        self.restore(replay_metrics=True)
        self.log_hyperparams(dict_from_config(self.cfg))
        last_epoch = self.cfg.optim.epochs
        for _ in range(self.epoch, last_epoch + 1):
            if self.should_stop_training():
                break
            self.run_epoch()
            # Commit will send the metrics to Dora and save checkpoints by default.
            self.commit()

    def should_stop_training(self) -> bool:
        """Check whether we should stop training or not."""
        return self.epoch > self.cfg.optim.epochs

    def should_run_stage(self, stage_name) -> bool:
        """Check whether we want to run the specified stages."""
        stage_every = self.cfg[stage_name].get('every', None)
        is_last_epoch = self.epoch == self.cfg.optim.epochs
        is_epoch_every = (stage_every and self.epoch % stage_every == 0)
        return is_last_epoch or is_epoch_every

    @abstractmethod
    def run_step(self, idx: int, batch: tp.Any, metrics: dict):
        """Perform one training or valid step on a given batch.

        Args:
            idx (int): Index of the batch in the current stage.
            batch (Any): Batch obtained from the dataloader of the stage.
            metrics (dict): Metrics dictionary; `common_train_valid` passes an empty one.
        Returns:
            The metrics of the step: `common_train_valid` uses the returned value, which it
            feeds to its averagers.
        """
        ...

    def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
        """Common logic for train and valid stages.

        Iterates over the dataloader of `dataset_split`, calls `run_step` on each batch, steps
        the EMA during training and keeps a running average of the step metrics.

        Args:
            dataset_split (str): Key of the dataloader to use in `self.dataloaders`.
            **kwargs: Accepted but not used in this implementation.
        Returns:
            The averaged metrics, as returned by `flashy.distrib.average_metrics`.
        """
        # train mode only for the train stage, eval mode otherwise
        self.model.train(self.is_training)

        loader = self.dataloaders[dataset_split]
        # get a different order for distributed training, otherwise this will get ignored
        if flashy.distrib.world_size() > 1 \
           and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
            loader.sampler.set_epoch(self.epoch)
        # training length comes from the solver setting, other stages go through the whole loader
        updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader)
        if self.cfg.benchmark_no_load:
            self.logger.warning("Fake loading for benchmarking: re-using first batch")
            batch = next(iter(loader))
            loader = [batch] * updates_per_epoch  # type: ignore
        lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
        average = flashy.averager()  # epoch wise average
        instant_average = flashy.averager()  # average between two logging
        metrics: dict = {}

        with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
            for idx, batch in enumerate(lp):
                self.deadlock_detect.update('batch')
                # the loader may provide more batches than the number of updates wanted
                if idx >= updates_per_epoch:
                    break
                metrics = {}
                metrics = self.run_step(idx, batch, metrics)
                self.deadlock_detect.update('step')
                # run EMA step
                if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0:
                    self.logger.debug("EMA model step")
                    self.ema.step()
                self.deadlock_detect.update('ema')
                self.profiler.step()
                instant_metrics = instant_average(metrics)
                if lp.update(**instant_metrics):
                    instant_average = flashy.averager()  # reset averager between two logging
                metrics = average(metrics)  # epoch wise average
                self.deadlock_detect.update('end_batch')

        # presumably averages the metrics across workers, weighted by the number of updates;
        # confirm against flashy.distrib.average_metrics
        metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
        return metrics

    def train(self):
        """Train stage."""
        return self.common_train_valid('train')

    def valid(self):
        """Valid stage."""
        return self.common_train_valid('valid')

    @abstractmethod
    def evaluate(self):
        """Evaluate stage.

        Run through `run_stage` by `run_epoch` and `run_one_stage`, in both cases under
        `torch.no_grad()` and with the best state swapped in.
        """
        ...

    @abstractmethod
    def generate(self):
        """Generate stage.

        Run through `run_stage` by `run_epoch` and `run_one_stage`, wrapped with
        `with_rank_rng()`, under `torch.no_grad()` and with the best state swapped in.
        """
        ...

    def run_one_stage(self, stage_name: str):
        """Run a single stage (`generate`, `evaluate` or `valid`) with the best state loaded.

        This is useful to only generate samples from a trained experiment or to rerun the
        validation or evaluation stages. Unless `cfg.execute_inplace` is set, the resulting
        metrics are committed, without saving checkpoints.

        Raises:
            ValueError: If `stage_name` is not one of the supported stages.
        """
        stage_fns = {
            'generate': with_rank_rng()(self.generate),
            'evaluate': self.evaluate,
            'valid': self.valid,
        }
        stage_fn = stage_fns.get(stage_name)
        if stage_fn is None:
            raise ValueError(f'Trying to run stage {stage_name} is not supported.')
        assert len(self.state_dict()) > 0
        self._start_epoch()
        with torch.no_grad():
            with self.swap_best_state():
                self.run_stage(stage_name, stage_fn)
        if self.cfg.execute_inplace:
            return
        self.commit(save_checkpoints=False)

    @staticmethod
    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                                 device: tp.Optional[str] = None, autocast: bool = True,
                                 batch_size: tp.Optional[int] = None,
                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                                 **kwargs):
        """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
        populating all the proper param, deactivating EMA, FSDP, loading the best state,
        basically all you need to get a solver ready to "play" with in single GPU mode
        and with minimal memory overhead.

        Args:
            sig (str): signature to load.
            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
            device (str or None): potential device, as a string, i.e. 'cuda'.
            autocast (bool): value forced for the `autocast` config entry.
            batch_size (int or None): potential value for the `dataset.batch_size` config entry.
            override_cfg (dict or omegaconf.DictConfig or None): extra config overrides. The values
                set by this function are merged on top of them and win in case of conflict.
            **kwargs: forwarded to `train.get_solver_from_sig`.
        Returns:
            The solver, with its model switched to eval mode.
        """
        from audiocraft import train
        forced_cfg: tp.Dict[str, tp.Any] = {
            'optim': {'ema': {'use': False}},
            'autocast': autocast,
        }
        # optional top-level entries are only set when explicitly requested
        for key, value in (('dtype', dtype), ('device', device)):
            if value is not None:
                forced_cfg[key] = value
        if batch_size is not None:
            forced_cfg['dataset'] = {'batch_size': batch_size}
        user_cfg = {} if override_cfg is None else override_cfg
        merged_cfg = omegaconf.OmegaConf.merge(
            omegaconf.DictConfig(user_cfg), omegaconf.DictConfig(forced_cfg))  # type: ignore
        solver = train.get_solver_from_sig(
            sig, override_cfg=merged_cfg,
            load_best=True, disable_fsdp=True,
            ignore_state_keys=['optimizer', 'ema'], **kwargs)
        solver.model.eval()
        return solver