def _init_data(self, data, mask=None):
    """Validate the observed tensor (and optional mask) and record its sizes.

    A dense ``np.ndarray`` is first converted to an ``skt.sptensor`` built
    from its nonzero entries. The resulting tensor must be 4-dimensional
    with equal sizes on its first two axes. Its shape is unpacked as
    ``(V, V, A, T)`` and stored on ``self`` as ``n_actors`` (V),
    ``n_actions`` (A) and ``n_timesteps`` (T).

    Args:
        data: ``np.ndarray`` or ``skt.sptensor`` of rank 4.
        mask: Optional ``np.ndarray`` of rank 2 or 3 with an integer dtype
            whose last two axes have size ``(V, V)``. It is only validated
            here; it is neither stored nor returned.

    Returns:
        The data as an ``skt.sptensor``.

    Raises:
        TypeError: If ``data`` is neither an ndarray nor an sptensor, if
            ``mask`` is not an ndarray, or if ``mask`` is not integer-typed.
        ValueError: If ``data`` or ``mask`` has the wrong rank or shape.
    """
    if isinstance(data, np.ndarray):
        # Compute the nonzero coordinates once; they supply both the
        # subscripts and the values of the sparse tensor.
        subs = data.nonzero()
        data = skt.sptensor(subs, data[subs], data.shape)

    # Explicit raises rather than `assert`: asserts are removed under
    # `python -O`, which would let malformed input through silently.
    if not isinstance(data, skt.sptensor):
        raise TypeError(
            "data must be a numpy.ndarray or skt.sptensor, got %s"
            % type(data).__name__
        )
    if data.ndim != 4:
        raise ValueError("data must be 4-dimensional, got ndim=%d" % data.ndim)
    if data.shape[0] != data.shape[1]:
        raise ValueError(
            "data.shape[0] and data.shape[1] must be equal, got %r"
            % (tuple(data.shape),)
        )
    V, A, T = data.shape[1:]

    # Validate the mask before touching `self`, so a bad mask does not
    # leave the object partially initialised.
    if mask is not None:
        if not isinstance(mask, np.ndarray):
            raise TypeError(
                "mask must be a numpy.ndarray, got %s" % type(mask).__name__
            )
        if mask.ndim not in (2, 3):
            raise ValueError(
                "mask must be 2- or 3-dimensional, got ndim=%d" % mask.ndim
            )
        if mask.shape[-2:] != (V, V):
            raise ValueError(
                "mask's last two axes must have shape %r, got %r"
                % ((V, V), mask.shape[-2:])
            )
        if not np.issubdtype(mask.dtype, np.integer):
            raise TypeError(
                "mask must have an integer dtype, got %s" % mask.dtype
            )

    self.n_actors = V
    self.n_actions = A
    self.n_timesteps = T
    return data
评论列表
文章目录