vrnn_model.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:vrnn 作者: frhrdr 项目源码 文件源码
def gm_log_p(params_out, x_target, dim):
    """ computes log probability of target in Gaussian mixture with given parameters """
    mean_x, cov_x, pi_x_logit = params_out
    pi_x = tf.nn.softmax(pi_x_logit)
    mean_x = tf.transpose(mean_x, perm=[1, 0, 2])
    cov_x = tf.transpose(cov_x, perm=[1, 0, 2])
    pi_x = tf.transpose(pi_x, perm=[1, 0])

    x_diff = x_target - mean_x
    x_square = tf.reduce_sum((x_diff / cov_x) * x_diff, axis=[2])
    log_x_exp = -0.5 * x_square
    log_cov_x_det = tf.reduce_sum(tf.log(cov_x), axis=[2])
    log_x_norm = -0.5 * (dim * tf.log(2 * np.pi) + log_cov_x_det) + pi_x
    log_p = tf.reduce_logsumexp(log_x_norm + log_x_exp, axis=[0])
    return log_p, log_x_norm, log_x_exp, tf.abs(x_diff)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号