test_deterministic_policy.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def _test_call(self, gpu):
        """Smoke-test the model's call interface on a random input batch.

        This only checks that a model built by ``self._make_model`` can
        receive random input data and return output data with the correct
        interface; it does not check the returned values beyond the action
        bounds.

        Args:
            gpu: Device id. When ``gpu >= 0`` the model, the input and the
                action bounds are moved to that GPU; otherwise everything
                stays as NumPy arrays.
        """
        nonlinearity = getattr(F, self.nonlinearity)
        min_action = np.full((self.action_size,), -0.01, dtype=np.float32)
        max_action = np.full((self.action_size,), 0.01, dtype=np.float32)
        model = self._make_model(
            n_input_channels=self.n_input_channels,
            action_size=self.action_size,
            bound_action=self.bound_action,
            min_action=min_action,
            max_action=max_action,
            nonlinearity=nonlinearity,
        )

        batch_size = 7
        x = np.random.rand(
            batch_size, self.n_input_channels).astype(np.float32)
        if gpu >= 0:
            model.to_gpu(gpu)
            x = chainer.cuda.to_gpu(x)
            # The bounds are compared against the sampled action below, so
            # they are moved to the same device as the model output.
            min_action = chainer.cuda.to_gpu(min_action)
            max_action = chainer.cuda.to_gpu(max_action)
        y = model(x)
        # assertIsInstance reports the offending object and expected type on
        # failure, unlike assertTrue(isinstance(...)).
        self.assertIsInstance(
            y, chainerrl.distribution.ContinuousDeterministicDistribution)
        a = y.sample()
        self.assertIsInstance(a, chainer.Variable)
        self.assertEqual(a.shape, (batch_size, self.action_size))
        # The sampled action must use the same array module as the input.
        self.assertEqual(chainer.cuda.get_array_module(a),
                         chainer.cuda.get_array_module(x))
        if self.bound_action:
            self.assertTrue((a.data <= max_action).all())
            self.assertTrue((a.data >= min_action).all())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号