def test_forward_invalid(self):
    # Exercises the input type checking of F.Linear(5, 5): a float32 input
    # of shape (1, 5) must produce a chainer.Variable, while a float64 input
    # and a 1-D input must each raise InvalidType with the expected message.
    link = F.Linear(5, 5)
    invalid_type = chainer.utils.type_check.InvalidType

    def check_rejected(array, pattern):
        # Wrap the raw array in a Variable, then require that calling the
        # link fails with InvalidType and a message matching `pattern`.
        bad_input = chainer.Variable(array)
        with six.assertRaisesRegex(self, invalid_type, pattern):
            link(bad_input)

    # Well-formed input: accepted, and the result is a Variable.
    good_input = chainer.Variable(
        numpy.random.randn(1, 5).astype(numpy.float32))
    output = link(good_input)
    assert isinstance(output, chainer.Variable)

    # Wrong dtype: the array is left as produced by randn (the expected
    # message reports float64).
    # In py3, numpy dtypes are represented as class, hence (type|class).
    # The continuation lines stay at column 0 because they are part of the
    # regex text itself.
    dtype_pattern = """\
Invalid operation is performed in: LinearFunction \\(Forward\\)
Expect: in_types\\[0\\]\\.dtype == <(type|class) 'numpy\\.float32'>
Actual: float64 \\!= <(type|class) 'numpy\\.float32'>"""
    check_rejected(numpy.random.randn(1, 5), dtype_pattern)

    # Wrong number of dimensions: a 1-D array where ndim >= 2 is required.
    ndim_pattern = """\
Invalid operation is performed in: LinearFunction \\(Forward\\)
Expect: in_types\\[0\\]\\.ndim >= 2
Actual: 1 < 2"""
    check_rejected(
        numpy.random.randn(5).astype(numpy.float32), ndim_pattern)
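# Comment list (web page navigation text, not part of the code)
# Table of contents (web page navigation text, not part of the code)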