test_state_action_q_function.py source file

python

Project: chainerrl  Author: chainer
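import chainer
import numpy as np


# Method of a test case class; it expects self.n_dim_obs and
# self.n_dim_action to be set and uses the unittest assertion helpers.
def _test_call_given_model(self, model, gpu):
    # This method only checks whether a given model can receive random
    # input data and return output data with the correct interface.
    batch_size = 7
    obs = np.random.rand(batch_size, self.n_dim_obs).astype(np.float32)
    action = np.random.rand(
        batch_size, self.n_dim_action).astype(np.float32)
    if gpu >= 0:
        # Move both the model and its inputs to the requested GPU device.
        model.to_gpu(gpu)
        obs = chainer.cuda.to_gpu(obs)
        action = chainer.cuda.to_gpu(action)
    y = model(obs, action)
    # The output must be a Variable holding one Q-value per sample.
    self.assertIsInstance(y, chainer.Variable)
    self.assertEqual(y.shape, (batch_size, 1))
    # The output must live on the same backend (NumPy or CuPy) as the input.
    self.assertEqual(chainer.cuda.get_array_module(y),
                     chainer.cuda.get_array_module(obs))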
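
The interface check above can also be sketched without chainer or a GPU. The following is a minimal NumPy-only illustration, not chainerrl code: `LinearQFunction` and `check_call` are hypothetical names introduced here, and the stand-in returns a plain `np.ndarray` rather than a `chainer.Variable`. It only mirrors the shape and type assertions of the test.

```python
import numpy as np


class LinearQFunction:
    """Hypothetical stand-in for a state-action Q-function (not part of chainerrl)."""

    def __init__(self, n_dim_obs, n_dim_action, seed=0):
        rng = np.random.RandomState(seed)
        # One weight per concatenated (obs, action) feature, single scalar output.
        self.w = rng.rand(n_dim_obs + n_dim_action, 1).astype(np.float32)

    def __call__(self, obs, action):
        # Concatenate along the feature axis, then project to one Q-value per row.
        x = np.concatenate([obs, action], axis=1)
        return x.dot(self.w)


def check_call(model, n_dim_obs, n_dim_action, batch_size=7):
    """Feed random input to the model and check the output interface."""
    obs = np.random.rand(batch_size, n_dim_obs).astype(np.float32)
    action = np.random.rand(batch_size, n_dim_action).astype(np.float32)
    y = model(obs, action)
    # Mirror the original assertions: output type, shape, and matching dtype.
    assert isinstance(y, np.ndarray)
    assert y.shape == (batch_size, 1)
    assert y.dtype == obs.dtype
    return y


check_call(LinearQFunction(n_dim_obs=5, n_dim_action=3), 5, 3)
```

As in the original test, only the interface is exercised; the values of the returned Q-estimates are not checked.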