test_basic.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_broadcasted(self):
        """Check ``choose`` against ``numpy.choose`` when broadcasting occurs.

        Both cases run for every mode in ``self.modes``:

        1. ``a`` is an int32 scalar and ``b`` a float32 matrix, so the
           scalar index is broadcast against the choices.
        2. ``b`` is a float32 column, so each choice has length 1; the
           symbolic result's first dimension must then be broadcastable.
        """
        # Fixed seed so that any failure can be reproduced exactly.
        rng = numpy.random.RandomState(1234)
        a = tensor.scalar(dtype='int32')
        A = 3

        def check(b, B, expect_broadcastable):
            # Compile choose(a, b) for each mode and compare with numpy.
            for m in self.modes:
                out = choose(a, b, mode=m)
                if expect_broadcastable:
                    assert out.broadcastable[0], m
                f = function([a, b], out)
                t_c = f(A, B)
                n_c = numpy.choose(A, B, mode=m)
                # numpy.allclose broadcasts its inputs, so a wrongly shaped
                # result could still compare equal; check shapes explicitly.
                assert t_c.shape == n_c.shape, (m, t_c.shape, n_c.shape)
                assert numpy.allclose(t_c, n_c), m

        # Test when a is broadcastable
        check(tensor.matrix(dtype='float32'),
              numpy.asarray(rng.rand(4, 4), dtype='float32'),
              expect_broadcastable=False)

        # Test when the result should be broadcastable
        check(theano.tensor.col(dtype='float32'),
              numpy.asarray(rng.rand(4, 1), dtype='float32'),
              expect_broadcastable=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号