def test_gemv_dimensions(self):
    """Check that a gemv-style expression validates input shapes.

    Compiles ``z = beta * y + alpha * T.dot(A, x)``, where ``alpha`` and
    ``beta`` are shared scalars both equal to 1.  With compatible inputs
    (a 5x3 matrix, a length-3 ``x`` and a length-5 ``y``) the function
    must return ``y + A.dot(x)``.  This must also hold when ``A`` is a view
    with negative strides in both dimensions.  A mismatch between A's
    columns and ``len(x)``, or between A's rows and ``len(y)``, must raise
    ``ValueError``.
    """
    A = T.matrix('A')
    x, y = T.vectors('x', 'y')
    alpha = theano.shared(theano._asarray(1.0, dtype=config.floatX),
                          name='alpha')
    beta = theano.shared(theano._asarray(1.0, dtype=config.floatX),
                         name='beta')

    z = beta * y + alpha * T.dot(A, x)
    f = theano.function([A, x, y], z)

    # Matrix value
    A_val = numpy.ones((5, 3), dtype=config.floatX)
    # Vectors of several lengths; only ones_3 (for x) and ones_5 (for y)
    # are compatible with the 5x3 A_val.
    ones_3 = numpy.ones(3, dtype=config.floatX)
    ones_4 = numpy.ones(4, dtype=config.floatX)
    ones_5 = numpy.ones(5, dtype=config.floatX)
    ones_6 = numpy.ones(6, dtype=config.floatX)

    # Compatible shapes: the result must match a NumPy reference.
    # alpha == beta == 1, so z == y + A.dot(x).
    self.assertTrue(numpy.allclose(f(A_val, ones_3, ones_5),
                                   ones_5 + numpy.dot(A_val, ones_3)))
    # The same computation on a view reversed along both axes, so the
    # compiled function is also fed a non-contiguous (negatively strided)
    # matrix.
    A_rev = A_val[::-1, ::-1]
    self.assertTrue(numpy.allclose(f(A_rev, ones_3, ones_5),
                                   ones_5 + numpy.dot(A_rev, ones_3)))

    # Mismatched shapes must be rejected.
    self.assertRaises(ValueError, f, A_val, ones_4, ones_5)
    self.assertRaises(ValueError, f, A_val, ones_3, ones_6)
    self.assertRaises(ValueError, f, A_val, ones_4, ones_6)
# The following gemv tests were added in March 2011 by Ian Goodfellow
# and are based on the gemv tests from scipy
# http://projects.scipy.org/scipy/browser/trunk/scipy/linalg/tests/test_fblas.py?rev=6803
# NOTE: At the time these tests were written, theano did not have a
# conjugate function. If such a thing is ever added, the tests involving
# conjugate should be ported over as well.
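# NOTE(review): two stray web-page navigation headings ("comment list" /
# "table of contents") were pasted here. They are not part of the test code.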