test_opt.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_local_sumsqr2dot():
    """Check the ``local_sumsqr2dot`` graph rewrite numerically and structurally.

    The symbolic graph broadcasts W (rows j) against G (rows i), multiplies
    elementwise, squares, and sums over the last two axes:

        out[i] = sum_{j,k} (W[j, k] * G[i, k]) ** 2
               = sum_k G[i, k] ** 2 * sum_j W[j, k] ** 2

    The right-hand form is a matrix-vector product. The test verifies:

    1. The compiled function matches the NumPy reference computation.
    2. The optimized graph contains at least one dot-style op (``Dot``,
       ``Dot22``, ``Gemv`` or ``CGemv``), i.e. the rewrite fired.
    """
    g_sym = matrix('G')
    w_sym = matrix('W')

    # W gets a leading broadcast axis and G a middle one, so the product is
    # indexed as [i (row of G), j (row of W), k (shared column)].
    w_bcast = w_sym.dimshuffle('x', 0, 1)
    g_bcast = g_sym.dimshuffle(0, 'x', 1)
    out = T.sqr(w_bcast * g_bcast).sum(axis=(1, 2))

    opt_mode = theano.compile.get_default_mode().including('local_sumsqr2dot')
    fn = function([w_sym, g_sym], out, mode=opt_mode)

    # Draw W first, then G, so the random stream is consumed in a fixed order.
    w_data = numpy.random.rand(4, 3).astype(config.floatX)
    g_data = numpy.random.rand(5, 3).astype(config.floatX)

    got = fn(w_data, g_data)

    # Reference: squared G rows dotted with the column-wise sum of squared W.
    w_col_sq_sums = numpy.square(w_data).sum(axis=0)
    expected = numpy.dot(numpy.square(g_data), w_col_sq_sums)
    utt.assert_allclose(got, expected)

    dot_like_ops = (tensor.basic.Dot, tensor.blas.Dot22,
                    tensor.blas.Gemv, tensor.blas_c.CGemv)
    graph_nodes = fn.maker.fgraph.toposort()
    assert any(isinstance(node.op, dot_like_ops) for node in graph_nodes)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号