from torch.optim.lr_scheduler import _LRScheduler


# This is rather suboptimal, because we need to import a protected class. Unfortunately, I don't see another way.
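
# (Note: newer PyTorch versions also expose this base class publicly as
# torch.optim.lr_scheduler.LRScheduler; the protected name is used here for
# compatibility with older versions.)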


class ToucanWarmupScheduler(_LRScheduler):
    """
    A warmup scheduler that should be called after every batch.
    """

    def __init__(self, optimizer, peak_lr=0.0002, warmup_steps=20000, max_steps=200000, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.peak_lr = peak_lr
        self.max_steps = max_steps
        self.plateau = self.warmup_steps * 4  # hold the peak LR for four times the warmup duration
        self.last_lr = 0.0
        # the fields above must be set before calling super().__init__(),
        # because step() (and thus get_lr()) is already invoked inside it
        super().__init__(optimizer, last_epoch)

    def __repr__(self):
        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"

    def get_lr(self):
        step_num = self.last_epoch + 1
        if step_num <= self.warmup_steps:
            # phase 1: linear warmup from 0 to peak_lr
            lr = self.peak_lr * min(step_num / self.warmup_steps, 1.0)
            self.last_lr = lr
            return [lr for _ in self.base_lrs]
        elif step_num < self.warmup_steps + self.plateau:
            # phase 2: hold the peak LR constant for the plateau duration
            self.last_lr = self.peak_lr
            return [self.peak_lr for _ in self.base_lrs]
        else:
            # phase 3: multiplicative decay by a factor slightly below 1 per step,
            # floored at 1e-7
            scale = 1 - (((step_num - (self.warmup_steps + self.plateau)) / self.max_steps) / (self.max_steps / 10))
            self.last_lr = max(self.last_lr * scale, 1e-7)
            return [self.last_lr for _ in self.base_lrs]
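
# Usage sketch (illustrative only; the model, optimizer and loop below are placeholders,
# not the project's actual training setup): the scheduler wraps a torch optimizer and is
# stepped once per batch, right after optimizer.step().
#
#     import torch
#     model = torch.nn.Linear(80, 80)
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
#     scheduler = ToucanWarmupScheduler(optimizer, peak_lr=0.0002, warmup_steps=20000, max_steps=200000)
#     for batch_index in range(100):
#         optimizer.zero_grad()
#         # ... forward pass and loss.backward() would go here ...
#         optimizer.step()
#         scheduler.step()  # advances the internal step counter (last_epoch) by one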


class WarmupScheduler(_LRScheduler):
    """
    The WarmupLR scheduler.
    This scheduler is almost the same as the NoamLR scheduler, except for the following difference:
    NoamLR:
        lr = optimizer.lr * model_size ** -0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    WarmupLR:
        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    Note that the maximum lr equals optimizer.lr in this scheduler.

    Taken from ESPnet
    """

    def __init__(self, optimizer, warmup_steps=25000, last_epoch=-1):
        self.warmup_steps = warmup_steps
        # warmup_steps must be set before calling super().__init__(),
        # because step() (and thus get_lr()) is already invoked inside it
        super().__init__(optimizer, last_epoch)

    def __repr__(self):
        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"

    def get_lr(self):
        step_num = self.last_epoch + 1
        return [lr * self.warmup_steps ** 0.5 * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5) for lr in
                self.base_lrs]
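
# Worked example of the WarmupLR formula above (the base lr of 0.001 and the default
# warmup_steps of 25000 are used purely for illustration):
#   at step == warmup_steps:      lr = 0.001 * 25000 ** 0.5 * 25000 ** -0.5       = 0.001   (peak, equal to the base lr)
#   at step == 4 * warmup_steps:  lr = 0.001 * 25000 ** 0.5 * (4 * 25000) ** -0.5 = 0.0005  (inverse-square-root decay)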


if __name__ == '__main__':
    # quick visual sanity check: re-implements the warmup -> plateau -> decay shape of
    # ToucanWarmupScheduler with standalone hyperparameters and plots the resulting curve
    import matplotlib.pyplot as plt

    lrs = list()
    warmup_steps = 30000
    peak_lr = 0.0005
    max_steps = 800000
    plateau_size = warmup_steps * 5
    for step_num in range(max_steps):
        if step_num <= warmup_steps:
            lr = peak_lr * min(step_num / warmup_steps, 1.0)
            lrs.append(lr)
        elif step_num < warmup_steps + plateau_size:
            lrs.append(peak_lr)
        else:
            scale = 1 - (((step_num - (warmup_steps + plateau_size)) / max_steps) / (max_steps / 10))
            lrs.append(max(lrs[-1] * scale, 1e-7))

    plt.plot(lrs)
    plt.show()