def test_local_sumsqr2dot():
    """Check that ``local_sumsqr2dot`` rewrites a sum of squared products.

    For each row ``i`` of ``G``, the graph built here computes::

        y[i] = sum_{j, k} (W[j, k] * G[i, k]) ** 2
             = sum_k G[i, k] ** 2 * sum_j W[j, k] ** 2

    That equals ``dot(G ** 2, (W ** 2).sum(axis=0))``.

    The test compiles the graph with the ``local_sumsqr2dot`` optimization
    included. It then compares the result with that NumPy reference and
    asserts that a dot-like op is present in the optimized graph.
    """
    G = matrix('G')
    W = matrix('W')
    # W becomes (1, rows_W, cols) and G becomes (rows_G, 1, cols). The
    # product therefore pairs every row of G with every row of W. Summing
    # over axes 1 and 2 leaves one value per row of G.
    y = T.sqr(W.dimshuffle('x', 0, 1) * G.dimshuffle(0, 'x', 1)).sum(axis=(1, 2))

    mode = theano.compile.get_default_mode().including('local_sumsqr2dot')
    f = function([W, G], y, mode=mode)

    # Seeded RNG so that any failure can be reproduced exactly.
    rng = numpy.random.RandomState(utt.fetch_seed())
    w_val = rng.rand(4, 3).astype(config.floatX)
    g_val = rng.rand(5, 3).astype(config.floatX)

    f_val = f(w_val, g_val)
    f_test = numpy.dot(numpy.square(g_val), numpy.square(w_val).sum(axis=0))
    utt.assert_allclose(f_val, f_test)

    # Accept a plain Dot or any of the listed BLAS specializations. Which one
    # survives presumably depends on the rewrites that run after
    # local_sumsqr2dot.
    dot_ops = (tensor.basic.Dot, tensor.blas.Dot22,
               tensor.blas.Gemv, tensor.blas_c.CGemv)
    topo = f.maker.fgraph.toposort()
    assert any(isinstance(n.op, dot_ops) for n in topo), (
        "no dot-like op found in optimized graph: %s" % [n.op for n in topo])
评论列表
文章目录