def testWhileLoopProblem(self):
    """Checks L2L meta-optimization on a loss built with `tf.while_loop`.

    No values are asserted: the test passes as long as constructing the
    meta-minimize ops and running `train` on them complete without raising.
    """

    def _keep_looping(step, _):
        # The counter starts at 0, so the loop body executes exactly once.
        return step < 1

    def _square_and_advance(step, value):
        return step + 1, value * value

    def while_loop_problem():
        """Returns the square of scalar variable "x", computed in a loop."""
        x = tf.get_variable("x", shape=[], initializer=tf.ones_initializer())
        # Roundabout way of squaring the variable, so that the problem
        # contains a while loop.
        _final_step, squared = tf.while_loop(
            cond=_keep_looping,
            body=_square_and_advance,
            loop_vars=(0, x),
            name="loop")
        return squared

    net_config = {
        "net": "CoordinateWiseDeepLSTM",
        "net_options": {"layers": ()},
    }
    meta_optimizer = meta.MetaOptimizer(net=net_config)
    # NOTE(review): the meaning of the positional 3 here, and of the 1, 2
    # passed to `train` below, is defined outside this chunk -- confirm
    # against those signatures.
    update_ops = meta_optimizer.meta_minimize(while_loop_problem, 3)

    with self.test_session() as session:
        session.run(tf.global_variables_initializer())
        train(session, update_ops, 1, 2)
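# Page-navigation residue from the web source (not part of the test code):
# "Comment list" / "Table of contents".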