def get_marginal_likelihood(yt, mean_yt, xt, s, alpha, beta, eta_mu, eta_sigma,
                            eps, sigma_px, epsilon=1e-8):
    """Estimate the batch-averaged log marginal likelihood log p(y) from ``s`` samples.

    For every sample the log weight is

        sum_d Bernoulli log-likelihood(y_d | mean_yt_d)
        + log p_x
        + 0.5 * sum(eps ** 2)

    The per-example estimate is the log of the mean of the exponentiated
    weights over the sample axis. It is computed with the max-subtraction
    (log-mean-exp) trick so ``tf.exp`` does not overflow or underflow. The
    result is then averaged over the remaining axis.

    Args:
        yt: observations. A leading axis is added and they are compared
            element-wise with ``mean_yt`` in a Bernoulli log-likelihood, so
            presumably ``[batch, D]`` values in [0, 1] -- TODO confirm.
        mean_yt: per-sample Bernoulli means. Reshaped to
            ``[s, FLAGS.batch_size, -1]``; the last dimension was previously
            hard-coded to 784.
        xt: latent samples. Reshaped to
            ``[1, s, FLAGS.batch_size, FLAGS.hidden_size]`` before being passed
            to ``gaussian_mixture_pdf``.
        s: number of samples per example (the leading reshape dimension).
        alpha, beta: stick-breaking parameters. Only ``alpha / (alpha + beta)``
            (the mean of Beta(alpha, beta)) is used.
        eta_mu: passed as the first argument of ``gaussian_mixture_pdf``
            (presumably the component means -- TODO confirm).
        eta_sigma: squared and added to ``sigma_px ** 2`` to form the second
            argument of ``gaussian_mixture_pdf`` (presumably component variances).
        eps: its squared values, summed over axis 2, are added to each log
            weight. Presumably the standard-normal noise used to draw ``xt``
            -- TODO confirm.
        sigma_px: extra scale whose square is added to ``eta_sigma ** 2``.
        epsilon: small constant that keeps the Bernoulli logs away from log(0).

    Returns:
        Scalar tensor: the mean over the batch of the estimated log p(y).
    """
    # Add a leading axis so the observations broadcast against the sample axis.
    yt_expand = tf.expand_dims(yt, 0)
    # -1 lets TF infer the observation size instead of hard-coding 784.
    mean_yt = tf.reshape(mean_yt, [s, FLAGS.batch_size, -1])
    xt = tf.reshape(xt, [1, s, FLAGS.batch_size, FLAGS.hidden_size])

    # Stick-breaking mixture weights: pi_k = v_k * prod_{j<k} (1 - v_j).
    # The final component receives the leftover stick prod_j (1 - v_j),
    # so there are len(v) + 1 weights.
    # tf.concat(axis, values) is the pre-1.0 argument order used by this file.
    v = alpha / (alpha + beta)
    pi = tf.concat(0, [v, [1.0]]) * tf.concat(0, [[1.0], tf.cumprod(1 - v)])
    p_x = gaussian_mixture_pdf(eta_mu, tf.square(eta_sigma) + tf.square(sigma_px), xt, pi)

    # Bernoulli log-likelihood summed over the observation axis.
    # epsilon guards log(0) at mean values of exactly 0 or 1.
    log_p_y_given_x = tf.reduce_sum(
        yt_expand * tf.log(mean_yt + epsilon)
        + (1.0 - yt_expand) * tf.log(1.0 - mean_yt + epsilon), 2)
    # NOTE(review): p_x is not guarded; a density that underflows to 0 yields
    # -inf here (and NaN below if every sample of an example is -inf).
    log_p_x = tf.log(p_x)
    # NOTE(review): if eps is the reparameterisation noise of a diagonal
    # Gaussian q(x|y), then -log q also needs sum(log sigma_q) + D/2*log(2*pi).
    # Only the quadratic part is added here -- confirm this is intended.
    eps_term = 0.5 * tf.reduce_sum(tf.square(eps), 2)
    # Same summation order as before, so the floating-point result is unchanged.
    log_p_y_s = log_p_y_given_x + log_p_x + eps_term

    # log-mean-exp over the sample axis (axis 0). Subtracting the per-example
    # max keeps exp() in range, and the max is added back afterwards.
    log_p_y_s_max = tf.reduce_max(log_p_y_s, 0)
    log_p_y = tf.log(tf.reduce_mean(tf.exp(log_p_y_s - log_p_y_s_max), 0)) + log_p_y_s_max
    return tf.reduce_mean(log_p_y)
# Taken from: https://github.com/tensorflow/tensorflow/issues/6322
评论列表
文章目录