relaxed_onehot_categorical.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def _sample_n(self, n, seed=None):
    sample_shape = array_ops.concat(([n], array_ops.shape(self.logits)), 0)
    logits = self.logits * array_ops.ones(sample_shape)
    if logits.get_shape().ndims == 2:
      logits_2d = logits
    else:
      logits_2d = array_ops.reshape(logits, [-1, self.event_size])
    np_dtype = self.dtype.as_numpy_dtype()
    minval = np.nextafter(np_dtype(0), np_dtype(1))
    uniform = random_ops.random_uniform(shape=array_ops.shape(logits_2d),
                                        minval=minval,
                                        maxval=1,
                                        dtype=self.dtype,
                                        seed=seed)
    gumbel = - math_ops.log(- math_ops.log(uniform))
    noisy_logits = math_ops.div(gumbel + logits_2d, self.temperature)
    samples = nn_ops.log_softmax(noisy_logits)
    ret = array_ops.reshape(samples, sample_shape)
    return ret
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号