transform.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:space-wrappers 作者: ngc92 项目源码 文件源码
def flatten(space):
    """
    Build a transform between `space` and a flattened version of it.

    For continuous spaces (`Box`) the space is reshaped to be of rank 1.
    For multidimensional discrete spaces (`MultiDiscrete`, `MultiBinary`)
    a single `Discrete` space is created that has one value for every
    possible combination of the original sub-actions.
    Please be aware that the latter can be potentially pathological in case
    the input space has many discrete actions, as the number of single
    discrete actions increases exponentially ("curse of dimensionality").

    :param gym.Space space: The space that will be flattened.
    :return Transform: A transform object describing the transformation
            to the flattened space. `convert_to` maps a value of `space` to
            the flat space, `convert_from` maps a flat value back.
            If `is_compound(space)` is false, an identity transform is
            returned.
    :raises TypeError: per the original contract, if `space` is not a
            `gym.Space`. No such check is made in this function itself, so
            this is presumably raised by `is_compound` (to be confirmed).
    :raises NotImplementedError: if `is_compound(space)` is true but the
            space is neither `Box` nor `MultiDiscrete` nor `MultiBinary`.
    """
    # no need to do anything if already flat
    if not is_compound(space):
        return Transform(space, space, lambda x: x, lambda x: x)

    if isinstance(space, spaces.Box):
        shape = space.low.shape
        lo = space.low.flatten()
        hi = space.high.flatten()

        def convert(x):
            # flat vector -> value in the original shape
            return np.reshape(x, shape)

        def back(x):
            # value in the original shape -> flat vector
            return np.reshape(x, lo.shape)

        flat_space = spaces.Box(low=lo, high=hi)
        return Transform(original=space, target=flat_space, convert_from=convert, convert_to=back)

    if isinstance(space, (spaces.MultiDiscrete, spaces.MultiBinary)):
        if isinstance(space, spaces.MultiDiscrete):
            # bounds are inclusive on both ends, hence the +1
            ranges = [range(space.low[i], space.high[i] + 1) for i in range(space.num_discrete_space)]
        else:
            # MultiBinary: every one of the `n` components is either 0 or 1
            ranges = [range(2) for _ in range(space.n)]

        # Enumerate every combination once; the position in `lookup` is the
        # flat discrete action, `inverse_lookup` maps a combination back to it.
        lookup = list(itertools.product(*ranges))
        inverse_lookup = {combo: index for index, combo in enumerate(lookup)}
        flat_space = spaces.Discrete(len(lookup))

        def convert(x):
            # flat index -> tuple of sub-actions
            return lookup[x]

        def back(x):
            # The table is keyed by tuples. Lists and numpy arrays are
            # unhashable, so normalise the key before the lookup.
            return inverse_lookup[tuple(x)]

        return Transform(original=space, target=flat_space, convert_from=convert, convert_to=back)

    raise NotImplementedError("Does not know how to flatten {}".format(type(space)))  # pragma: no cover


# rescale a continuous action space
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号