test_mellowmax.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def check_forward(self, x_data):
        xp = cuda.get_array_module(x_data)
        y = maximum_entropy_mellowmax(x_data)
        self.assertEqual(y.data.dtype, self.dtype)

        print('y', y.data)

        # Outputs must be positive
        xp.testing.assert_array_less(xp.zeros_like(y.data), y.data)

        # Sums must be ones
        sums = xp.sum(y.data, axis=1)
        testing.assert_allclose(sums, xp.ones_like(sums))

        # Expectations must be equal to memllowmax's outputs
        testing.assert_allclose(
            xp.sum(y.data * x_data, axis=1), mellowmax(x_data, axis=1).data)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号