# Module-level imports assumed from the surrounding optimizers_test.py
# (TensorFlow 1.x contrib test file); this excerpt is a single test method.
from tensorflow.contrib.layers.python.layers import optimizers as optimizers_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import gradient_descent


def testUpdateOpFromCollection(self):
  # optimize_loss accepts an optimizer name, an optimizer class, or an
  # optimizer instance; the test exercises all three forms.
  optimizers = [
      "SGD", gradient_descent.GradientDescentOptimizer,
      gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
  ]
  for optimizer in optimizers:
    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      x, var, loss, global_step = _setup_model()
      # Register an extra op in the UPDATE_OPS collection; optimize_loss
      # should fold it into the returned train op.
      update_var = variable_scope.get_variable(
          "update", [], initializer=init_ops.constant_initializer(10))
      update_op = state_ops.assign(update_var, 20)
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op)
      train = optimizers_lib.optimize_loss(
          loss, global_step, learning_rate=0.1, optimizer=optimizer)
      variables.global_variables_initializer().run()
      session.run(train, feed_dict={x: 5})
      var_value, update_var_value, global_step_value = session.run(
          [var, update_var, global_step])
      # One SGD step was applied, the collected update op ran (10 -> 20),
      # and global_step was incremented once.
      self.assertEqual(var_value, 9.5)
      self.assertEqual(update_var_value, 20)
      self.assertEqual(global_step_value, 1)
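
The helper _setup_model is not shown in this excerpt. A minimal sketch consistent with the assertions above: the variable starts at 10 and the gradient of the loss with respect to it equals the fed value x, so one SGD step with learning rate 0.1 and x = 5 gives 10 - 0.1 * 5 = 9.5. The variable names and the exact loss below are assumptions for illustration, not the file's actual definition (it reuses the imports listed above):

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def _setup_model():
  # Hypothetical reconstruction: any model where var == 10 initially and
  # d(loss)/d(var) == x reproduces the test's expected values.
  x = array_ops.placeholder(dtypes.float32, [])
  var = variable_scope.get_variable(
      "test", [], initializer=init_ops.constant_initializer(10))
  loss = math_ops.abs(var * x)  # with x = 5: gradient w.r.t. var is 5
  global_step = variable_scope.get_variable(
      "global_step", [],
      trainable=False,
      dtype=dtypes.int64,
      initializer=init_ops.constant_initializer(0, dtype=dtypes.int64))
  return x, var, loss, global_step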
Source: optimizers_test.py