def gumbel_softmax(logits, temperature, hard=False):
"""Sample from the Gumbel-Softmax distribution and optionally discretize.
Args:
logits: [batch_size, n_class] unnormalized log-probs
temperature: non-negative scalar
hard: if True, take argmax, but differentiate w.r.t. soft sample y
Returns:
[batch_size, n_class] sample from the Gumbel-Softmax distribution.
If hard=True, then the returned sample will be one-hot, otherwise it will
be a probabilitiy distribution that sums to 1 across classes
"""
y = gumbel_softmax_sample(logits, temperature)
#if hard:
# k = tf.shape(logits)[-1]
# #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)
# y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype)
# y = tf.stop_gradient(y_hard - y) + y
return y
评论列表
文章目录