import theano
import theano.tensor as T
from theano.sandbox.rng_mrg import MRG_RandomStreams

# Assumed helpers: `lib.floatX(x)` casts `x` to theano.config.floatX, and
# `srng` is a shared random stream (names taken from the surrounding project).
import lib

srng = MRG_RandomStreams(seed=234)


def softmax_and_sample(logits, temperature=1.):
"""
:temperature: default 1.
For high temperatures (temperature -> +Inf), all actions have nearly the same
probability and the lower the temperature, the more expected rewards affect
the probability. For a low temperature (temperature -> 0+), the probability of
the action with the highest expected reward (max operation) tends to 1.
"""
    temperature = lib.floatX(temperature)
    ZEROX = lib.floatX(0.)
    assert temperature >= ZEROX, "`temperature` should be a non-negative value!"
    old_shape = logits.shape
    flattened_logits = logits.reshape((-1, logits.shape[logits.ndim - 1]))
    if temperature == ZEROX:
        # Zero temperature: take the (deterministic) max instead of a
        # (biased) sample. Equivalent to taking the argmax of the logits
        # directly, but keeping the softmax makes it easy to extract the
        # probabilities later on too.
        samples = T.nnet.softmax(flattened_logits)
    else:  # temperature > 0
        # Scale the logits by the temperature, then draw one-hot samples
        # from the resulting softmax distribution.
        flattened_logits /= temperature
        samples = T.cast(
            srng.multinomial(pvals=T.nnet.softmax(flattened_logits)),
            theano.config.floatX
        )
    # Restore the original shape; argmax over the last axis recovers the
    # sampled (or max-probability) class index from each one-hot row.
    samples = samples.reshape(old_shape)
    return T.argmax(samples, axis=samples.ndim - 1)
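
# Usage sketch (illustrative only): `logits_var`, the (4, 10) shape, and the
# 0.7 temperature below are made-up example values, not part of the original
# function. Assumes the imports at the top of this listing.
import numpy as np

logits_var = T.matrix('logits')  # (batch, n_classes)
sample_fn = theano.function(
    [logits_var],
    softmax_and_sample(logits_var, temperature=0.7)
)

logits_np = np.random.randn(4, 10).astype(theano.config.floatX)
print(sample_fn(logits_np))  # four sampled class indices, one per row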