def test_copy_from_works_with_control_flow(self):
    """Check that `copy_from` can be gated by a `tf.cond` predicate.

    Two single-unit dense modules are built on the same input. The copy
    from the first module into the second is wrapped in a `tf.cond` that
    fires only when the fed integer is a multiple of 5. The test asserts
    that the outputs differ initially, match after a run where the
    predicate is true, and differ again after the first module's first
    variable is incremented in a run where the predicate is false.
    """
    def graph_fn1(mode, x):
        dense = plx.layers.Dense(units=1)
        return dense(x)

    def graph_fn2(mode, x):
        dense = plx.layers.Dense(units=1, trainable=False)
        return dense(x)

    source_module = plx.FunctionModule(
        mode=plx.Modes.TRAIN, build_fn=graph_fn1, name='fn1')
    target_module = plx.FunctionModule(
        mode=plx.Modes.TRAIN, build_fn=graph_fn2, name='fn2')

    x = tf.placeholder(dtype=tf.float32, shape=[1, 1])
    source_out = source_module(x)
    target_out = target_module(x)

    init_all_op = tf.global_variables_initializer()

    def copy():
        # The copy op has to be created inside a function: if it were built
        # outside the branch it would always be evaluated, regardless of
        # the condition.
        return target_module.copy_from(
            source_module, tf.GraphKeys.GLOBAL_VARIABLES)

    def skip_copy():
        return tf.no_op()

    a = tf.placeholder(tf.int32, ())
    is_multiple_of_five = tf.equal(tf.mod(a, 5), 0)
    cond = tf.cond(is_multiple_of_five, copy, skip_copy)

    first_source_variable = source_module.get_variables()[0]
    assign_op = first_source_variable.assign_add([[1]])
    group_op = tf.group(assign_op, cond)

    feed = {x: [[1]]}

    def evaluate_outputs():
        # Evaluate both module outputs on the same fixed input.
        return source_out.eval(feed), target_out.eval(feed)

    with self.test_session() as sess:
        sess.run(init_all_op)

        # Initially the two modules hold different values.
        source_result, target_result = evaluate_outputs()
        assert source_result[0] != target_result[0]

        # Predicate is true (10 % 5 == 0): the copy runs.
        sess.run(cond, feed_dict={a: 10})
        source_result, target_result = evaluate_outputs()
        assert source_result[0] == target_result[0]

        # Increment the source variable while the predicate is false
        # (2 % 5 != 0): no copy happens, so the outputs diverge again.
        sess.run(group_op, feed_dict={a: 2})
        source_result, target_result = evaluate_outputs()
        assert source_result[0] != target_result[0]
评论列表
文章目录