def get_optimizer(config):
if(config['optimizer'] == 'rmsprop'):
opti = optimizers.rmsprop(lr=config['learning_rate'],
clipvalue=config['grad_clip'],
decay=config['decay_rate'])
return opti
elif(config['optimizer'] == 'adadelta'):
opti = optimizers.adadelta(lr=config['learning_rate'],
clipvalue=config['grad_clip'])
return opti
elif(config['optimizer'] == 'sgd'):
opti = optimizers.sgd(lr=config['learning_rate'],
momentum=config['momentum'],
decay=config['learning_rate_decay'])
return opti
else:
raise StandardError('optimizer name error')
评论列表
文章目录