Spaces:
Sleeping
Sleeping
import torch | |
from torch.optim.lr_scheduler import LRScheduler | |
class LinearSchedulerWithWarmup(LRScheduler):
    """Linear warmup followed by linear decay of every base learning rate.

    The multiplier applied to each param group's base LR at step ``s`` is:

    * ``s / max(1, num_warmup_steps)`` while ``s < num_warmup_steps``
      (so it starts at 0 and ramps up linearly);
    * afterwards ``(num_training_steps - s) / max(1, num_training_steps -
      num_warmup_steps)``, i.e. a linear decay from 1 at the end of warmup to
      0 at ``num_training_steps``, clamped at 0 beyond that.

    Args:
        optimizer: Wrapped optimizer.
        num_warmup_steps: Number of warmup steps.
        num_training_steps: Step at which the learning rate reaches 0.
        last_epoch: Index of the last step taken; ``-1`` starts from scratch.
        verbose: Forwarded to ``LRScheduler`` only when truthy. PyTorch 2.2+
            deprecates this argument and warns on any explicitly passed value,
            so the default ``False`` is not forwarded.
        **kwargs: Accepted and ignored (presumably so a shared config dict can
            be splatted into the constructor — confirm against callers).
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_warmup_steps: int,
        num_training_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
        **kwargs,
    ):
        # Set before super().__init__, which performs an initial step that
        # calls get_lr() and therefore reads these attributes.
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            # Passing verbose=False explicitly triggers a deprecation warning
            # on PyTorch 2.2+; omitting it yields the same default behaviour.
            super().__init__(optimizer, last_epoch)

    def _lr_factor(self, current_step: int) -> float:
        """Return the multiplier applied to every base LR at ``current_step``."""
        if current_step < self.num_warmup_steps:
            return current_step / max(1, self.num_warmup_steps)
        return max(
            0.0,
            float(self.num_training_steps - current_step)
            / float(max(1, self.num_training_steps - self.num_warmup_steps)),
        )

    def get_lr(self):
        """Return the learning rate of each param group at ``self.last_epoch``."""
        # The multiplier is shared by all param groups; compute it once.
        factor = self._lr_factor(self.last_epoch)
        return [base_lr * factor for base_lr in self.base_lrs]
class LinearScheduler(LRScheduler):
    """Linear decay of every base learning rate, with no warmup.

    The multiplier applied to each param group's base LR at step ``s`` is
    ``(num_training_steps - s) / max(1, num_training_steps)``, clamped at 0:
    it starts at 1 and reaches 0 at ``num_training_steps``.

    Args:
        optimizer: Wrapped optimizer.
        num_training_steps: Step at which the learning rate reaches 0.
        last_epoch: Index of the last step taken; ``-1`` starts from scratch.
        verbose: Forwarded to ``LRScheduler`` only when truthy. PyTorch 2.2+
            deprecates this argument and warns on any explicitly passed value,
            so the default ``False`` is not forwarded.
        **kwargs: Accepted and ignored (presumably so a shared config dict can
            be splatted into the constructor — confirm against callers).
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_training_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
        **kwargs,
    ):
        # Set before super().__init__, which performs an initial step that
        # calls get_lr() and therefore reads this attribute.
        self.num_training_steps = num_training_steps
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            # Passing verbose=False explicitly triggers a deprecation warning
            # on PyTorch 2.2+; omitting it yields the same default behaviour.
            super().__init__(optimizer, last_epoch)

    def _lr_factor(self, current_step: int) -> float:
        """Return the multiplier applied to every base LR at ``current_step``."""
        return max(
            0.0,
            float(self.num_training_steps - current_step)
            / float(max(1, self.num_training_steps)),
        )

    def get_lr(self):
        """Return the learning rate of each param group at ``self.last_epoch``."""
        # The multiplier is shared by all param groups; compute it once.
        factor = self._lr_factor(self.last_epoch)
        return [base_lr * factor for base_lr in self.base_lrs]