File size: 278 Bytes
47c46ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torch.optim import Adam
from torch.optim.lbfgs import LBFGS
from .radam import RAdam


# Registry of the optimizer classes selectable by name. Keys are lower-case
# because get_optimizer_class() lower-cases the requested name before lookup.
OPTIMIZER_MAP = dict(
    adam=Adam,
    radam=RAdam,
    lbfgs=LBFGS,
)


def get_optimizer_class(optimizer_name):
    """Return the optimizer class registered under ``optimizer_name``.

    The lookup is case-insensitive: the name is lower-cased before being
    matched against the keys of ``OPTIMIZER_MAP``.

    Args:
        optimizer_name: Name of the optimizer as a string (it must provide
            ``.lower()``), e.g. ``"adam"`` or ``"Adam"``.

    Returns:
        The class stored in ``OPTIMIZER_MAP`` for that name. The class itself
        is returned, not an instance.

    Raises:
        KeyError: If no optimizer is registered under the given name. The
            message lists the registered names.
    """
    name = optimizer_name.lower()
    if name not in OPTIMIZER_MAP:
        # Still a KeyError, so callers that catch the plain dict-lookup
        # failure keep working, but now it says what would have been valid.
        raise KeyError(
            "Unknown optimizer {!r}; supported optimizers: {}".format(
                optimizer_name, ", ".join(sorted(OPTIMIZER_MAP))
            )
        )
    return OPTIMIZER_MAP[name]