test_computations.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_sum_prod_broadcast(self):
        # placeholder
        a = tf.placeholder(tf.float32, shape=[3, 4, 5, 6])
        b = tf.placeholder(tf.float32, shape=[3, 4, 5])
        a_sum = tf.reduce_sum(a, reduction_indices=[0, 3])  # shape (4, 5)
        b_prod = tf.reduce_prod(b, reduction_indices=[0, 1])  # shape (5,)
        f = a_sum + b_prod + b  # (4, 5) + (5,) + (3, 4, 5) -> (3, 4, 5)

        # value
        feed_dict = dict()
        for x in [a, b]:
            feed_dict[x] = np.random.rand(*tf_obj_shape(x))

        # test
        self.run(f, tf_feed_dict=feed_dict)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号