def get_S_loss(alpha, beta, mean_x, logcov_x, mean_eta, logcov_eta, sigma2, epsilon=1e-8):
    """Score every (sample, component) pair; return hard assignments and a log-sum-exp loss.

    For sample ``i`` and component ``k`` the score is::

        S[i, k] = 0.5 * sum_d(1 + logcov_x[i, d] - log(sigma2)
                              - (exp(logcov_x[i, d]) + exp(logcov_eta[k, d])
                                 + (mean_x[i, d] - mean_eta[k, d]) ** 2) / sigma2)
                  + log_w[k]

    ``log_w`` is assembled from ``alpha`` and ``beta`` as::

        log_w[k] = E[log v_k] + sum_{j<k} E[log(1 - v_j)]
        E[log v]     = digamma(alpha) - digamma(alpha + beta)
        E[log(1-v)]  = digamma(beta)  - digamma(alpha + beta)

    The last component uses ``E[log v_K] = 0``, and the first component has an
    empty sum.
    NOTE(review): this matches truncated stick-breaking weights with
    Beta(alpha, beta) variational posteriors — confirm against the model.

    Args:
        alpha, beta: 1-D tensors of equal length. The zero-padded weights (one
            entry longer) are added along the last axis of S. Presumably the
            length is K - 1 for K components.
        mean_x, logcov_x: sample means and log-variances. They are expanded on
            axis 1 and reduced over axis 2, presumably from shape (N, D).
            TODO confirm.
        mean_eta, logcov_eta: component means and log-variances. They are
            expanded on axis 0, presumably from shape (K, D). TODO confirm.
        sigma2: positive real number. It goes through ``math.log``, so it must
            be a plain scalar, not a graph tensor.
        epsilon: constant added inside the final log.

    Returns:
        assignments: ``argmax`` of S along axis 1, the highest-scoring component
            per sample.
        S_loss: scalar ``-sum_i logsumexp_k S[i, k]``, computed stably by
            subtracting the per-row maximum before exponentiating.
    """
    # Insert singleton axes so sample terms (axis 1) broadcast against
    # component terms (axis 0).
    mean_x_pad = tf.expand_dims(mean_x, 1)
    logcov_x_pad = tf.expand_dims(logcov_x, 1)
    mean_eta_pad = tf.expand_dims(mean_eta, 0)
    logcov_eta_pad = tf.expand_dims(logcov_eta, 0)

    S1 = tf.digamma(alpha) - tf.digamma(alpha + beta)
    S2 = tf.cumsum(tf.digamma(beta) - tf.digamma(alpha + beta))
    # Append a 0 to S1 (the last component takes E[log v_K] = 0) and prepend a
    # 0 to S2 (no preceding terms for the first component). tf.pad is used
    # instead of tf.concat for two reasons. First, tf.concat's argument order
    # changed in TensorFlow 1.0: concat(axis, values) became
    # concat(values, axis). Second, tf.pad keeps alpha's dtype instead of
    # forcing float32.
    S1_padded = tf.pad(S1, [[0, 1]])
    S2_padded = tf.pad(S2, [[1, 0]])

    sq_term = (tf.exp(logcov_x_pad) + tf.exp(logcov_eta_pad)
               + tf.square(mean_x_pad - mean_eta_pad))
    S = (0.5 * tf.reduce_sum(1 + logcov_x_pad - math.log(sigma2) - sq_term / sigma2, 2)
         + S1_padded + S2_padded)

    # Positional axis arguments work across TF versions; the `dimension=` and
    # `reduction_indices=` keywords are deprecated.
    assignments = tf.argmax(S, 1)
    S_max = tf.reduce_max(S, 1)
    # Stable log-sum-exp over components. The shifted sum is always >= 1
    # because the max entry contributes exp(0) = 1. Epsilon therefore only
    # nudges the result; it is kept for backward compatibility.
    shifted_sum = tf.reduce_sum(tf.exp(S - tf.expand_dims(S_max, 1)), 1)
    S_loss = -tf.reduce_sum(S_max) - tf.reduce_sum(tf.log(shifted_sum + epsilon))
    return assignments, S_loss
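# Leftover web-page navigation labels "评论列表" ("Comment list") and
# "文章目录" ("Table of contents") from the page this code was copied from; not part of the code.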