def check_forward(self, xs):
    """Assert that ``weighted_sum_arrays`` matches a hand-computed sum.

    The function under test runs on ``xs``, but the expected value is built
    from ``self.xs`` and ``self.weights`` as ``sum(x_i * w_i)``. The actual
    output's ``.data`` is copied with ``cuda.to_cpu`` before comparison.

    Args:
        xs: Input arrays passed to ``weighted_sum_arrays``. Presumably either
            ``self.xs`` or device copies of it; TODO confirm against callers.
    """
    weights = self.weights
    actual = chainerrl.functions.weighted_sum_arrays(xs, weights=weights)
    # Reference uses self.xs rather than xs, presumably so it is computed from
    # host-side arrays even when xs lives on a GPU; verify against callers.
    expected = sum(arr * weight for arr, weight in zip(self.xs, weights))
    # NOTE(review): newer Chainer exposes this as chainer.testing.assert_allclose
    # -- confirm which version this file targets.
    gradient_check.assert_allclose(expected, cuda.to_cpu(actual.data))
评论列表
文章目录