def get_learning_rate_policy_callback(lr_params):
    """Build a learning-rate policy object from ``lr_params``.

    Args:
        lr_params: Either a real number or a mapping.
            A real number is treated as the ``base_lr`` of a ``'fixed'``
            policy. A mapping must have a ``'name'`` key selecting an entry
            of ``lrp.lr_policies``, plus every key listed in that entry's
            ``'args'``.

    Returns:
        Whatever the selected policy's ``'obj'`` returns when called with the
        (possibly normalized) ``lr_params`` mapping.

    Raises:
        NotImplementedError: If ``lr_params['name']`` is not a key of
            ``lrp.lr_policies``.
        ValueError: If ``lr_params`` lacks one or more arguments required by
            the selected policy.
    """
    if isinstance(lr_params, numbers.Real):
        # A bare number is shorthand for a constant ("fixed") learning rate
        # using that number as base_lr.
        lr_params = {'name': 'fixed', 'base_lr': lr_params}

    policy_name = lr_params['name']
    if policy_name not in lrp.lr_policies:
        raise NotImplementedError(
            "Learning rate policy {lr_name} not supported."
            "\nSupported policies are: {policies}".format(
                lr_name=policy_name,
                policies=list(lrp.lr_policies))
        )

    policy = lrp.lr_policies[policy_name]
    required_args = policy['args']
    # Check that lr_params contains every argument the selected policy needs.
    missing_args = [arg for arg in required_args if arg not in lr_params]
    if missing_args:
        # Report the policy's required argument list (not the whole policy
        # entry, which also holds the constructor) and what is missing.
        raise ValueError(
            "Too few arguments provided to create policy {lr_name}."
            "\nGiven: {lr_params}"
            "\nExpected: {lr_args}"
            "\nMissing: {missing}".format(
                lr_name=policy_name,
                lr_params=list(lr_params),
                lr_args=list(required_args),
                missing=missing_args)
        )

    return policy['obj'](lr_params)
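# Scraped web-page residue (translated from Chinese): "Comment list"
# Scraped web-page residue (translated from Chinese): "Table of contents"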