def get_optimizer(self):
    """Build and return the Keras optimizer selected by ``self.opt``.

    Reads:
        self.opt: optimizer name, one of 'sgd', 'rmsprop', 'adagrad',
            'adadelta', 'adam'.
        self.learning_rate: learning rate passed to every optimizer.
        self.momentum: momentum, used only by 'sgd'.

    Returns:
        A ``k_opt`` (Keras optimizers module) optimizer instance.

    Raises:
        ValueError: if ``self.opt`` is not a recognized optimizer name.
            (Previously a bare ``Exception``; ``ValueError`` is the idiomatic
            type for a bad argument and still satisfies ``except Exception``.)
    """
    # Dispatch table of zero-arg factories: lambdas defer construction so
    # only the selected optimizer is instantiated.
    factories = {
        'sgd': lambda: k_opt.SGD(lr=self.learning_rate, momentum=self.momentum),
        'rmsprop': lambda: k_opt.RMSprop(lr=self.learning_rate),
        'adagrad': lambda: k_opt.Adagrad(lr=self.learning_rate),
        'adadelta': lambda: k_opt.Adadelta(lr=self.learning_rate),
        'adam': lambda: k_opt.Adam(lr=self.learning_rate),
    }
    try:
        factory = factories[self.opt]
    except KeyError:
        # `from None` hides the internal KeyError from the traceback.
        raise ValueError('Invalid optimization function - %s' % self.opt) from None
    return factory()
# NOTE(review): the two lines below are blog-page navigation labels
# ("comment list" / "article table of contents") left over from a web
# copy-paste; they are not Python code and should be deleted.
# 评论列表
# 文章目录