def space_shape(space):
    """Give the tensor shape that corresponds to ``space``.

    For an instance of ``spaces.Discrete`` the result is a one-element
    list containing ``space.n``. For every other kind of space the
    ``space.shape`` attribute is returned unchanged.
    """
    is_discrete = isinstance(space, spaces.Discrete)
    return [space.n] if is_discrete else space.shape