distributions.py source code

python

Project: multimodal_varinf  Author: tmoer
import tensorflow as tf


def gumbel_softmax(logits, temperature, hard=False):
  """Sample from the Gumbel-Softmax distribution and optionally discretize.
  Args:
    logits: [batch_size, n_class] unnormalized log-probs
    temperature: positive scalar
    hard: if True, take argmax, but differentiate w.r.t. soft sample y
  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot, otherwise it will
    be a probability distribution that sums to 1 across classes
  """
  # gumbel_softmax_sample is not shown in this snippet.
  y = gumbel_softmax_sample(logits, temperature)
  if hard:
    # One-hot at the row maximum (a tie marks every tied class).
    # Tie-free alternative:
    #   tf.cast(tf.one_hot(tf.argmax(y, 1), tf.shape(logits)[-1]), y.dtype)
    y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keepdims=True)), y.dtype)
    # Straight-through estimator: the forward value is y_hard,
    # while gradients flow through the soft sample y.
    y = tf.stop_gradient(y_hard - y) + y
  return y
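The function above delegates the actual sampling to `gumbel_softmax_sample`, which this snippet does not show. Below is a framework-free sketch of the forward pass, assuming that helper follows the standard Gumbel-Softmax construction: add Gumbel(0, 1) noise `-log(-log(U))` to the logits, divide by the temperature, and apply a softmax per row. The helper names, the `rng` parameter, and the list-of-lists input are hypothetical stand-ins, not the project's API. Plain Python has no autodiff, so the `hard` branch only reproduces the forward value of the straight-through trick.

```python
import math
import random
from typing import List, Optional

Matrix = List[List[float]]  # [batch_size, n_class]


def sample_gumbel(n: int, rng: random.Random, eps: float = 1e-20) -> List[float]:
    """Draw n i.i.d. Gumbel(0, 1) samples as -log(-log(U)), U ~ Uniform[0, 1)."""
    # eps guards against log(0) when U is exactly 0.
    return [-math.log(-math.log(rng.random() + eps) + eps) for _ in range(n)]


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_sample(logits: Matrix, temperature: float,
                          rng: random.Random) -> Matrix:
    """Hypothetical stand-in: softmax((logits + g) / temperature) per row."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    out = []
    for row in logits:
        g = sample_gumbel(len(row), rng)
        out.append(softmax([(l + gi) / temperature for l, gi in zip(row, g)]))
    return out


def gumbel_softmax(logits: Matrix, temperature: float, hard: bool = False,
                   rng: Optional[random.Random] = None) -> Matrix:
    """Forward pass only; hard=True returns one-hot rows."""
    rng = rng or random.Random()
    y = gumbel_softmax_sample(logits, temperature, rng)
    if hard:
        # One-hot at the argmax of each soft sample. This equals the forward
        # value of stop_gradient(y_hard - y) + y in the TensorFlow version.
        hard_rows = []
        for row in y:
            k = row.index(max(row))
            hard_rows.append([1.0 if j == k else 0.0 for j in range(len(row))])
        y = hard_rows
    return y


soft = gumbel_softmax([[1.0, 2.0, 3.0]], temperature=0.5, rng=random.Random(0))
print(soft)  # one row of non-negative values summing to 1
```

Lower temperatures push each soft row toward a one-hot vector; higher temperatures flatten it toward uniform, which is why `temperature` must be strictly positive.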