def get_S_loss_hao(
    mean_x, logcov_x, qv_alpha, qv_beta, qeta_mu, qeta_sigma, epsilon=1e-8
):
    """Build the score matrix ``S``, its row-wise softmax ``qz`` and the S loss.

    ``S`` is the sum of three pieces:

    * a Gaussian-style term computed from ``mean_x``, ``logcov_x``,
      ``qeta_mu`` and ``qeta_sigma`` and summed over axis 2 of the
      broadcast operands;
    * ``digamma(qv_alpha) - digamma(qv_alpha + qv_beta)`` with ``0.0``
      appended at the end;
    * the cumulative sum of ``digamma(qv_beta) - digamma(qv_alpha + qv_beta)``
      with ``0.0`` prepended at the front.

    NOTE(review): the last two pieces have the form of expected log
    stick-breaking weights under Beta(qv_alpha, qv_beta) factors -- confirm
    against the model definition.

    Args:
        mean_x: Tensor that receives a new axis at position 1.
            Presumably (batch, dim) -- TODO confirm with the caller.
        logcov_x: Log-variance tensor, expanded the same way as ``mean_x``.
        qv_alpha: Tensor concatenated along axis 0 with a one-element list,
            so presumably 1-D with one entry fewer than the number of
            columns of ``S`` -- TODO confirm.
        qv_beta: Same layout as ``qv_alpha``.
        qeta_mu: Tensor that is transposed and then given a new leading
            axis. Presumably (dim, components) -- TODO confirm.
        qeta_sigma: Same layout as ``qeta_mu``; it is squared inside the
            Gaussian term.
        epsilon: Small constant added inside the log of the loss.

    Returns:
        Tuple ``(S_loss, qz, S)`` where ``S_loss`` is the negated sum over
        rows of the (max-shifted) log-sum-exp of ``S`` along axis 1, ``qz``
        is the softmax of ``S`` along axis 1, and ``S`` is the score matrix
        itself.
    """
    sigma_px = 1.0

    # Digamma differences for the two Beta parameters; the second one is
    # accumulated along the (only) axis.
    digamma_total = tf.digamma(qv_alpha + qv_beta)
    first_term = tf.digamma(qv_alpha) - digamma_total
    second_term = tf.cumsum(tf.digamma(qv_beta) - digamma_total)

    # Insert singleton axes so the x-side and eta-side tensors broadcast
    # against each other.
    mean_b = tf.expand_dims(mean_x, 1)
    logcov_b = tf.expand_dims(logcov_x, 1)
    eta_mu_b = tf.expand_dims(tf.transpose(qeta_mu), 0)
    eta_sigma_b = tf.expand_dims(tf.transpose(qeta_sigma), 0)

    # Gaussian-style term, reduced over the last axis of the broadcast result.
    scaled_quadratic = (
        tf.exp(logcov_b)
        + tf.square(eta_sigma_b)
        + tf.square(mean_b - eta_mu_b)
    ) / tf.square(sigma_px)
    per_dim = 1 + logcov_b - 2 * tf.log(sigma_px) - scaled_quadratic
    gaussian_term = 0.5 * tf.reduce_sum(per_dim, 2)

    # NOTE(review): tf.concat is called with the axis as its first argument,
    # which matches the pre-1.0 TensorFlow signature -- confirm against the
    # TensorFlow version this file targets before changing it.
    first_padded = tf.concat(0, [first_term, [0.0]])
    second_padded = tf.concat(0, [[0.0], second_term])
    S = gaussian_term + first_padded + second_padded

    # Variational distribution q(z): softmax over axis 1, with the row
    # maximum subtracted first so the exponentials cannot overflow.
    row_max = tf.reduce_max(S, reduction_indices=1)
    shifted = S - tf.expand_dims(row_max, 1)
    exp_shifted = tf.exp(shifted)
    row_norm = tf.reduce_sum(exp_shifted, 1)
    qz = exp_shifted / tf.expand_dims(row_norm, 1)

    # S loss: max-shifted form of -sum(log(sum(exp(S), 1))), with epsilon
    # added inside the log.
    S_loss = -tf.reduce_sum(row_max) - tf.reduce_sum(tf.log(row_norm + epsilon))
    return S_loss, qz, S
评论列表
文章目录