test_template_module.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:polyaxon 作者: polyaxon 项目源码 文件源码
def test_copy_from_works_with_control_flow(self):
        """Check that `copy_from` inside a `tf.cond` branch runs only when the predicate holds.

        Two Dense-based modules are built on the same placeholder input:
        `l1` (trainable) and `l2` (`trainable=False`). A `tf.cond` keyed on
        `a % 5 == 0` either copies `l1`'s GLOBAL_VARIABLES into `l2` or is a
        no-op. The test verifies that:

        1. after initialization the two modules produce different outputs;
        2. running the cond with a true predicate (a=10) makes them equal;
        3. incrementing `l1`'s first variable while the predicate is false
           (a=2) makes them differ again, i.e. no copy was performed.
        """
        def graph_fn1(mode, x):
            return plx.layers.Dense(units=1)(x)

        def graph_fn2(mode, x):
            return plx.layers.Dense(units=1, trainable=False)(x)

        l1 = plx.FunctionModule(mode=plx.Modes.TRAIN, build_fn=graph_fn1, name='fn1')
        l2 = plx.FunctionModule(mode=plx.Modes.TRAIN, build_fn=graph_fn2, name='fn2')

        x = tf.placeholder(dtype=tf.float32, shape=[1, 1])

        lx1 = l1(x)
        lx2 = l2(x)

        init_all_op = tf.global_variables_initializer()

        def copy():
            # The copy op must be created inside this branch function; an op
            # built outside tf.cond would be evaluated regardless of the
            # condition.
            return l2.copy_from(l1, tf.GraphKeys.GLOBAL_VARIABLES)

        a = tf.placeholder(tf.int32, ())
        cond = tf.cond(tf.equal(tf.mod(a, 5), 0), copy, lambda: tf.no_op())
        # Bump l1's first variable by 1 so its output diverges from l2's.
        # NOTE(review): presumably index 0 is the Dense kernel; confirm the ordering.
        assign_op = l1.get_variables()[0].assign_add([[1]])
        group_op = tf.group(assign_op, cond)

        def first_rows():
            # Evaluate both modules on the same input and return row 0 of each.
            feed = {x: [[1]]}
            return lx1.eval(feed)[0], lx2.eval(feed)[0]

        with self.test_session() as sess:
            sess.run(init_all_op)

            # Initially the two modules have independently initialized weights.
            out1, out2 = first_rows()
            self.assertNotEqual(out1, out2)

            # Condition is True (10 % 5 == 0): l1's variables are copied into l2.
            sess.run(cond, feed_dict={a: 10})

            out1, out2 = first_rows()
            self.assertEqual(out1, out2)

            # Increment l1 while the condition is False (2 % 5 != 0): no copy,
            # so the outputs must diverge again.
            sess.run(group_op, feed_dict={a: 2})

            out1, out2 = first_rows()
            self.assertNotEqual(out1, out2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号