def check_shape_mismatch(self, x_data):
    """Expect ``backward()`` to raise a ValueError naming the dummy function.

    A throwaway ``chainer.Function`` is applied to ``x_data``. Its
    ``forward`` returns a zero-dimensional float32 array and its
    ``backward`` always returns a length-2 float32 array. Calling
    ``backward()`` on the resulting variable must raise ``ValueError``
    with a message matching ``'dummy_function'``.

    Args:
        x_data: Array wrapped in a ``chainer.Variable`` and fed to the
            dummy function. The array module used to build the dummy
            outputs is obtained from it via ``cuda.get_array_module``.
            NOTE(review): presumably its shape differs from ``(2,)`` so
            that the returned gradient mismatches -- confirm with callers.
    """
    array_module = cuda.get_array_module(x_data)

    class DummyFunction(chainer.Function):
        # The assertion below matches the error message against this text.
        label = 'dummy_function'

        def forward(self, inputs):
            # Zero-dimensional float32 output, independent of the inputs.
            output = array_module.array(1, np.float32)
            return (output,)

        def backward(self, inputs, grads):
            # Fixed length-2 gradient, independent of inputs and grads.
            fixed_grad = array_module.array([1, 2], np.float32)
            return (fixed_grad,)

    variable = chainer.Variable(x_data)
    result = DummyFunction()(variable)
    with six.assertRaisesRegex(self, ValueError, 'dummy_function'):
        result.backward()
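# Comment list (web-page navigation residue, not part of the code)
# Table of contents (web-page navigation residue, not part of the code)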