from functools import partial

import torch
from omegaconf import DictConfig


def get_optimizer(optimizer_name, parameters, lr=None, config: DictConfig = None):
    """Build a torch optimizer over ``parameters`` from its name and an optional config.

    Only options that are present (and truthy) in ``config`` are forwarded to the
    optimizer; everything else falls back to the torch defaults.
    """
    # Keyword arguments collected from `lr` and `config`, applied at construction time.
    param_dict = {}
    if lr is not None:
        param_dict['lr'] = lr

    if optimizer_name == "adam":
        optimizer = partial(torch.optim.Adam, params=parameters)
        if config is not None:
            if config.get('betas'):
                param_dict['betas'] = config.betas
            if config.get('weight_decay'):
                param_dict['weight_decay'] = config.weight_decay
            if config.get('eps'):
                param_dict['eps'] = config.eps

    elif optimizer_name == "adamw":
        optimizer = partial(torch.optim.AdamW, params=parameters)
        if config is not None:
            if config.get('betas'):
                param_dict['betas'] = config.betas
            if config.get('weight_decay'):
                param_dict['weight_decay'] = config.weight_decay
            if config.get('eps'):
                param_dict['eps'] = config.eps

    elif optimizer_name == "radam":
        optimizer = partial(torch.optim.RAdam, params=parameters)
        if config is not None:
            if config.get('betas'):
                param_dict['betas'] = config.betas
            if config.get('weight_decay'):
                param_dict['weight_decay'] = config.weight_decay

    elif optimizer_name == "sgd":
        optimizer = partial(torch.optim.SGD, params=parameters)
        if config is not None:
            if config.get('momentum'):
                param_dict['momentum'] = config.momentum
            if config.get('weight_decay'):
                param_dict['weight_decay'] = config.weight_decay
            if config.get('nesterov'):
                param_dict['nesterov'] = config.nesterov

    else:
        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")

    # Instantiate the optimizer with whatever options were collected.
    return optimizer(**param_dict)
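

# Example usage (a minimal sketch; the model, learning rate, and config values
# below are illustrative and not part of this module):
#
#   import torch.nn as nn
#   from omegaconf import OmegaConf
#
#   model = nn.Linear(10, 2)
#   cfg = OmegaConf.create({"weight_decay": 0.01, "betas": [0.9, 0.999]})
#   optimizer = get_optimizer("adamw", model.parameters(), lr=3e-4, config=cfg)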
|
|