ops.py source code

python

Project: mimicry.ai  Author: fizerkhan
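import theano
import theano.tensor as T

# NOTE: `lib` (a project helper module providing `floatX`) and `srng` (a Theano
# random stream) are not defined in this snippet; they are assumed to be
# defined elsewhere in the original project.
def softmax_and_sample(logits, temperature=1.):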
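    """
    Sample class indices from `logits` along the last axis, using a
    temperature-scaled softmax.

    :param temperature: non-negative scaling factor, default 1.
        For high temperatures (temperature -> +Inf), all actions have nearly
        the same probability. The lower the temperature, the more the
        expected rewards affect the probability. For a low temperature
        (temperature -> 0+), the probability of the action with the highest
        expected reward (max operation) tends to 1. With temperature == 0,
        the argmax is returned directly instead of a sample.
    """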
    temperature = lib.floatX(temperature)
    ZEROX = lib.floatX(0.)
    assert temperature >= ZEROX, "`temperature` should be a non-negative value!"
    old_shape = logits.shape
    flattened_logits = logits.reshape((-1, logits.shape[logits.ndim-1]))

    if temperature == ZEROX:
        # Take the max instead of drawing a (biased) sample.
        # Equivalent to taking the argmax of the logits directly, but going
        # through the softmax makes it easier to extract the probabilities
        # later on.
        samples = T.nnet.softmax(flattened_logits)
    else:  # temperature > 0
        flattened_logits = flattened_logits / temperature
        samples = T.cast(
            srng.multinomial(pvals=T.nnet.softmax(flattened_logits)),
            theano.config.floatX
        )
    samples = samples.reshape(old_shape)
    return T.argmax(samples, axis=samples.ndim-1)
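
The effect of `temperature` described in the docstring can be checked without Theano. The following is a minimal sketch in plain Python, not the project's implementation: the helper names `softmax` and `sample_index` are hypothetical, and it handles a single 1-D list of logits rather than a batched tensor.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(logits, temperature=1.0, rng=random):
    """Pick an index: argmax when temperature == 0, else a softmax sample."""
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0:
        # Deterministic branch: no sampling, just the largest logit.
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax(logits, temperature)
    # Draw one index with probability proportional to `probs`.
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [1.0, 2.0, 4.0]
cold = softmax(logits, temperature=0.1)    # mass concentrates on index 2
hot = softmax(logits, temperature=100.0)   # close to uniform
```

Here `cold` puts almost all probability on the largest logit, while `hot` is close to uniform, which matches the two limits described in the docstring.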