def GumbelSoftmaxLogDensity(y, p, tau):
# EPS = tf.constant(1e-10)
k = tf.shape(y)[-1]
k = tf.cast(k, tf.float32)
# y = y + EPS
# y = tf.divide(y, tf.reduce_sum(y, -1, keep_dims=True))
y = normalize_to_unit_sum(y)
sum_p_over_y = tf.reduce_sum(tf.divide(p, tf.pow(y, tau)), -1)
logp = tf.lgamma(k)
logp = logp + (k - 1) * tf.log(tau)
logp = logp - k * tf.log(sum_p_over_y)
logp = logp + sum_p_over_y
return logp
评论列表
文章目录