def test_create_optimizer(self):
    """Test that ``create_optimizer`` builds TensorFlow optimizers from a config dict.

    Cases covered:

    * a config without the required ``class`` entry raises ``AssertionError``;
    * a ``GradientDescentOptimizer`` config yields an instance of that class,
      and the graph contains a ``learning_rate:0`` tensor that evaluates to the
      configured value after variable initialisation;
    * a ``MomentumOptimizer`` config missing its required ``momentum``
      argument raises ``TypeError``.
    """
    optimizer_config = {'learning_rate': 0.1}
    # Missing required entry `class`.
    # NOTE(review): this relies on create_optimizer validating with `assert`,
    # which is stripped under `python -O` -- confirm that is intended.
    self.assertRaises(AssertionError, create_optimizer, optimizer_config)

    optimizer_config['class'] = 'tensorflow.python.training.gradient_descent.GradientDescentOptimizer'
    # A fresh graph keeps the name lookup below from colliding with
    # 'learning_rate' tensors left in the global default graph by other tests.
    # The `with`-managed session is closed when the block exits.
    with tf.Graph().as_default(), tf.Session() as sess:
        # The optimizer is created with the configured class.
        optimizer = create_optimizer(optimizer_config)
        self.assertIsInstance(optimizer, tf.train.GradientDescentOptimizer)
        # The learning-rate tensor exists and holds the configured value.
        lr_tensor = sess.graph.get_tensor_by_name('learning_rate:0')
        sess.run(tf.global_variables_initializer())
        self.assertAlmostEqual(sess.run(lr_tensor), 0.1)

    optimizer_config2 = {'learning_rate': 0.1, 'class': 'tensorflow.python.training.momentum.MomentumOptimizer'}
    # Missing required constructor argument (`momentum` for MomentumOptimizer).
    with tf.Graph().as_default():
        self.assertRaises(TypeError, create_optimizer, optimizer_config2)
# Comment list
# Table of contents