def prob_is_largest(self, Y, mu, var, gh_x, gh_w):
    """Approximate the probability that the latent function indexed by Y is the largest.

    For each row n, treating the latent functions as independent Gaussians
    f_k ~ N(mu[n, k], var[n, k]), this evaluates

        p(f_{Y_n} > f_k for all k != Y_n)
            = E_{f_{Y_n}} [ prod_{k != Y_n} Phi((f_{Y_n} - mu[n, k]) / sqrt(var[n, k])) ]

    with the outer expectation computed by Gauss-Hermite quadrature.

    Args:
        Y: integer labels; flattened to one class index per row, expected to
            lie in [0, self.num_classes).
        mu: latent means, expected shape (N, self.num_classes).
        var: latent variances, same shape as ``mu``.
        gh_x: Gauss-Hermite nodes (S of them, presumably a 1-D array).
        gh_w: Gauss-Hermite weights matching ``gh_x``. The sqrt(2 * var) node
            scaling and the 1 / sqrt(pi) weight factor applied here correspond
            to the physicists' convention (weight exp(-x**2)), as returned by
            ``numpy.polynomial.hermite.hermgauss``.

    Returns:
        A tensor of shape (N, 1) holding the quadrature estimate for each row.
    """
    Y = tf.cast(Y, tf.int64)
    labels = tf.reshape(Y, (-1,))

    # Mask selecting the labelled latent function in each row, and its complement.
    oh_on = tf.cast(tf.one_hot(labels, self.num_classes, 1., 0.), settings.float_type)
    oh_off = 1. - oh_on

    # Mean and variance of the labelled latent function, one value per row.
    mu_selected = tf.reduce_sum(oh_on * mu, 1)
    var_selected = tf.reduce_sum(oh_on * var, 1)

    # Gauss-Hermite grid for the labelled function:
    #   X[n, s] = mu_selected[n] + gh_x[s] * sqrt(2 * var_selected[n]).
    # The clip keeps the sqrt argument strictly positive.
    X = tf.reshape(mu_selected, (-1, 1)) + gh_x * tf.reshape(
        tf.sqrt(tf.clip_by_value(2. * var_selected, 1e-10, np.inf)), (-1, 1))

    # Standardised distance from every grid point to every latent function
    # (shape (N, num_classes, S) for a 1-D gh_x), including the selected one.
    dist = (tf.expand_dims(X, 1) - tf.expand_dims(mu, 2)) / tf.expand_dims(
        tf.sqrt(tf.clip_by_value(var, 1e-10, np.inf)), 2)
    # Standard normal CDF, squashed into [1e-4, 1 - 1e-4]; presumably this keeps
    # every factor of the product away from exactly 0 or 1 for numerical stability.
    cdfs = 0.5 * (1.0 + tf.erf(dist / np.sqrt(2.0)))
    cdfs = cdfs * (1 - 2e-4) + 1e-4

    # Force the labelled function's own factor to exactly 1 so it drops out of the product.
    cdfs = cdfs * tf.expand_dims(oh_off, 2) + tf.expand_dims(oh_on, 2)

    # Product over the latent functions (axis 1), then weighted sum over the GH grid.
    gh_weights = tf.reshape(gh_w / np.sqrt(np.pi), (-1, 1))
    return tf.matmul(tf.reduce_prod(cdfs, axis=1), gh_weights)
评论列表
文章目录