vrnn_model.py source code

python

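Project: vrnn  Author: frhrdr
import tensorflow as tf


def sample(params, eps, dist='gauss'):
    """Draw a reparameterized sample from a distribution, given noise eps.

    dist selects the family: 'gauss' (diagonal Gaussian, params = [mean, cov])
    or 'gm' (Gaussian mixture, params = [means, covs, pi_logits]). If dist also
    contains 'bin', params carries an extra trailing tensor of Bernoulli logits
    whose sigmoid is appended to the sample.
    """
    if 'bin' in dist:
        # The last entry holds Bernoulli logits; the rest parameterize the continuous part.
        logits = params[-1]
        params = params[:-1]
    if 'gauss' in dist:
        mean, cov = params
        # Reparameterization: s = mean + sqrt(variance) * eps.
        s = mean + tf.sqrt(cov) * eps
    elif 'gm' in dist:
        means, covs, pi_logits = params
        # Pick one mixture component per batch row from the mixing logits (int64, shape (batch, 1)).
        choices = tf.random.categorical(pi_logits, num_samples=1)
        # Build (row, component) index pairs; tf.shape keeps this valid when the batch size is unknown.
        batch_size = tf.shape(choices, out_type=tf.int64)[0]
        ids = tf.expand_dims(tf.range(batch_size), axis=1)
        idx_tensor = tf.concat([ids, choices], axis=1)
        chosen_means = tf.gather_nd(means, idx_tensor)
        chosen_covs = tf.gather_nd(covs, idx_tensor)
        s = chosen_means + tf.sqrt(chosen_covs) * eps
    else:
        raise NotImplementedError(dist)

    if 'bin' in dist:
        # Append Bernoulli probabilities (not sampled bits) to the continuous sample.
        sig = tf.sigmoid(logits)
        s = tf.concat([s, sig], axis=1)
    return s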
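To make the indexing and the reparameterization step easier to follow without a TensorFlow setup, here is a minimal NumPy sketch of the same sampling logic. It assumes `means`/`covs` of shape (batch, components, dim), `pi_logits` of shape (batch, components) and `eps` of shape (batch, dim). The name `sample_np`, the `rng` parameter and the use of `numpy.random.Generator.choice` in place of `tf.random.categorical` are illustrative assumptions, not part of the vrnn project.

```python
import numpy as np


def sample_np(params, eps, dist='gauss', rng=None):
    """Hypothetical NumPy mirror of sample(): reparameterized draw given noise eps."""
    rng = np.random.default_rng() if rng is None else rng
    params = list(params)  # copy so pop() never mutates the caller's list
    if 'bin' in dist:
        # Trailing entry holds Bernoulli logits.
        logits = params.pop()
    if 'gauss' in dist:
        mean, cov = params
        s = mean + np.sqrt(cov) * eps
    elif 'gm' in dist:
        means, covs, pi_logits = params
        # Softmax over components, then draw one component index per row.
        p = np.exp(pi_logits - pi_logits.max(axis=1, keepdims=True))
        p /= p.sum(axis=1, keepdims=True)
        choices = np.array([rng.choice(p.shape[1], p=row) for row in p])
        rows = np.arange(len(choices))
        # Fancy indexing with (row, component) pairs plays the role of tf.gather_nd.
        s = means[rows, choices] + np.sqrt(covs[rows, choices]) * eps
    else:
        raise NotImplementedError(dist)
    if 'bin' in dist:
        # Append sigmoid(logits), i.e. Bernoulli probabilities.
        s = np.concatenate([s, 1.0 / (1.0 + np.exp(-logits))], axis=1)
    return s
```

For example, with `mean = 0`, `cov = 4` and `eps = 1` every entry of the Gaussian sample is `0 + 2 * 1 = 2.0`. As in the TensorFlow version, the randomness in the mixture case only enters through the component choice; the Gaussian noise is supplied from outside through `eps`, which is what keeps the sample differentiable with respect to the means and variances.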