bptd.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:bptd 作者: aschein 项目源码 文件源码
def _init_data(self, data, mask=None):
        if isinstance(data, np.ndarray):
            data = skt.sptensor(data.nonzero(),
                                data[data.nonzero()],
                                data.shape)
        assert isinstance(data, skt.sptensor)
        assert data.ndim == 4
        assert data.shape[0] == data.shape[1]
        V, A, T = data.shape[1:]
        self.n_actors = V
        self.n_actions = A
        self.n_timesteps = T

        if mask is not None:
            assert isinstance(mask, np.ndarray)
            assert (mask.ndim == 2) or (mask.ndim == 3)
            assert mask.shape[-2:] == (V, V)
            assert np.issubdtype(mask.dtype, np.integer)

        return data
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号