def check_forward(self, c_data, x_data, y_data):
    """Check the forward pass of ``functions.where`` element by element.

    ``c_data``, ``x_data`` and ``y_data`` are wrapped in
    ``chainer.Variable`` and passed to ``functions.where``. The output must
    have the same shape as ``x``. Then, for every index of the condition
    array, the output element must equal the ``x`` element when the
    condition element is truthy, and the ``y`` element otherwise.

    Args:
        c_data: Condition array; only the truthiness of each element is
            used by this check.
        x_data: Array whose elements are expected where the condition holds.
        y_data: Array whose elements are expected where it does not.
    """
    c, x, y = (chainer.Variable(arr) for arr in (c_data, x_data, y_data))
    z = functions.where(c, x, y)

    self.assertEqual(x.data.shape, z.data.shape)

    # Visit every multi-index of the condition and compare the output with
    # whichever input the condition selects at that position.
    for idx in numpy.ndindex(c.data.shape):
        expected = x.data[idx] if c.data[idx] else y.data[idx]
        self.assertEqual(expected, z.data[idx])
评论列表
文章目录