def check_forward(self, data):
    """Assert that broadcasting ``data`` yields outputs of the expected shape and dtype.

    Every element of ``data`` is wrapped in a ``chainer.Variable``. All of the
    wrapped inputs are then passed together to ``functions.broadcast``. For each
    output, ``.data.shape`` must equal ``self.out_shape`` and ``.data.dtype`` must
    equal ``self.dtype``.

    Args:
        data: Iterable of input arrays to broadcast together.
    """
    inputs = [chainer.Variable(arr) for arr in data]
    result = functions.broadcast(*inputs)
    # With a single input, broadcast returns a lone Variable instead of a
    # collection of Variables, so wrap it so the loop below handles both cases.
    outputs = (result,) if isinstance(result, chainer.Variable) else result
    for out in outputs:
        self.assertEqual(out.data.shape, self.out_shape)
        self.assertEqual(out.data.dtype, self.dtype)