def test_constant(self):
    """Check test-value computation for a dot product with a constant operand.

    With ``theano.config.compute_test_value`` set to ``'raise'``, a
    ``T.constant`` of shape (2, 3) is multiplied with a shared variable of
    shape (3, 6).  The result must carry a ``tag.test_value`` that matches
    the output of the compiled function.  A second constant of shape
    (2, 4) is then expected to make ``T.dot`` raise ``ValueError``.

    The previous value of the config flag is restored on exit, whether
    the assertions pass or not.
    """
    saved_mode = theano.config.compute_test_value
    try:
        theano.config.compute_test_value = 'raise'

        lhs = T.constant(numpy.random.rand(2, 3), dtype=config.floatX)
        rhs_values = numpy.random.rand(3, 6).astype(config.floatX)
        rhs = theano.shared(rhs_values, 'y')

        # Compatible operands, (2, 3) x (3, 6): building the graph should
        # succeed and attach a test value to the result.
        product = T.dot(lhs, rhs)
        assert hasattr(product.tag, 'test_value')

        # The compiled function takes no inputs; its output must agree
        # with the test value stored on the symbolic result.
        compiled = theano.function([], product)
        assert _allclose(compiled(), product.tag.test_value)

        # Incompatible operands, (2, 4) x (3, 6): building the dot is
        # expected to raise (presumably because the test value is
        # computed while the graph is built -- confirm against Theano).
        bad_lhs = T.constant(numpy.random.rand(2, 4), dtype=config.floatX)
        self.assertRaises(ValueError, T.dot, bad_lhs, rhs)
    finally:
        theano.config.compute_test_value = saved_mode
test_compute_test_value.py 文件源码
python
阅读 36
收藏 0
点赞 0
评论 0
评论列表
文章目录