mellowmax.py source code

python

Project: chainerrl  Author: chainer
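import chainer
from chainer import functions as F
import numpy as np
import scipy.optimize

# mellowmax() is assumed to be defined earlier in this file (mellowmax.py).


def maximum_entropy_mellowmax(values, omega=1., beta_min=-10, beta_max=10):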
    """Maximum entropy mellowmax policy function.

    This function returns a categorical distribution whose expected value
    equals the mellowmax of the input values, chosen to have maximum entropy
    among all such distributions.

    See: http://arxiv.org/abs/1612.05628

    Args:
        values (Variable or ndarray):
            Input values. Mellowmax is taken along the second axis.
        omega (float):
            Parameter of mellowmax.
        beta_min (float):
            Minimum value of beta, used in Brent's algorithm.
        beta_max (float):
            Maximum value of beta, used in Brent's algorithm.

    Returns:
        Variable: Probabilities of the categorical distribution,
            softmax(beta * values) taken along the second axis.
    """
    xp = chainer.cuda.get_array_module(values)
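    # NOTE: omega is not forwarded to mellowmax() here, so the omega argument
    # of this function has no effect on the result.
    mm = mellowmax(values, axis=1)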

    # Advantage: Q - mellowmax(Q)
    batch_adv = values - F.broadcast_to(F.expand_dims(mm, 1), values.shape)
    # Move data to CPU because we use Brent's algorithm in scipy
    batch_adv = chainer.cuda.to_cpu(batch_adv.data)
    batch_beta = np.empty(mm.shape, dtype=np.float32)

    # Beta is the root of f: sum_i exp(beta * adv_i) * adv_i = 0, i.e. the
    # expected advantage under softmax(beta * values) is zero.
    def f(y, adv):
        return np.sum(np.exp(y * adv) * adv)

    for idx in np.ndindex(mm.shape):
        idx_full = idx[:1] + (slice(None),) + idx[1:]
        adv = batch_adv[idx_full]
        try:
            beta = scipy.optimize.brentq(
                f, a=beta_min, b=beta_max, args=(adv,))
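        except ValueError:
            # brentq raises ValueError when f has the same sign at beta_min
            # and beta_max (no bracketed root); fall back to beta = 0,
            # which yields a uniform distribution.
            beta = 0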
        batch_beta[idx] = beta

    return F.softmax(xp.expand_dims(xp.asarray(batch_beta), 1) * values)
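
The same idea can be sketched for a single row of values using only NumPy and SciPy. This is a minimal illustration, not the chainerrl implementation: the helper names mellowmax_1d and max_entropy_mellowmax_1d are hypothetical, the mellowmax formula log(mean(exp(omega * q))) / omega is assumed from the paper linked in the docstring, and omega is forwarded explicitly here.

```python
import numpy as np
from scipy.optimize import brentq


def mellowmax_1d(q, omega=1.0):
    """Mellowmax of a 1-D array: log(mean(exp(omega * q))) / omega."""
    q = np.asarray(q, dtype=np.float64)
    z = omega * q
    z_max = z.max()
    # Subtract the max before exponentiating for numerical stability.
    return (z_max + np.log(np.mean(np.exp(z - z_max)))) / omega


def max_entropy_mellowmax_1d(q, omega=1.0, beta_min=-10.0, beta_max=10.0):
    """Return softmax(beta * q) with beta chosen so that E[q] = mellowmax(q)."""
    q = np.asarray(q, dtype=np.float64)
    adv = q - mellowmax_1d(q, omega)  # advantage: Q - mellowmax(Q)

    def f(beta):
        # Unnormalized expected advantage under softmax(beta * q).
        return np.sum(np.exp(beta * adv) * adv)

    try:
        beta = brentq(f, beta_min, beta_max)
    except ValueError:
        # No sign change in [beta_min, beta_max]: use the uniform policy.
        beta = 0.0

    logits = beta * q
    logits = logits - logits.max()  # stable softmax
    p = np.exp(logits)
    return p / p.sum()


q = np.array([1.0, 2.0, 3.0])
probs = max_entropy_mellowmax_1d(q, omega=1.0)
```

When a root is found inside the bracket, the expected value of q under probs should agree with mellowmax_1d(q) up to the solver tolerance, which is the defining property described in the docstring.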