def exKxz(self, Z, Xmu, Xcov):
    """Return ``variance * Z @ (Xmu Xmu^T + Xcov)``, batched over Xmu's first axis.

    Presumably this is the expectation of ``x * K(x, Z)`` under a Gaussian
    with mean ``Xmu`` and covariance ``Xcov`` (inferred from the method
    name -- confirm against the enclosing kernel class).

    Args:
        Z: rank-2 tensor; it gains a leading axis and is tiled once per row
            of ``Xmu``. Assumed to be (M, D) so the batched matmul lines up.
        Xmu: tensor whose second dimension must equal ``self.input_dim``.
        Xcov: tensor whose first two dimensions must match the shape of
            ``Xmu``; added to the per-row outer product, so assumed (N, D, D).

    Returns:
        ``self.variance`` times the batched product of the tiled ``Z`` with
        ``Xmu Xmu^T + Xcov`` (shape (N, M, D) under the assumptions above).

    The two shape checks are graph assertions: they are attached to ``Xmu``
    as control dependencies and fire when the result is evaluated.
    """
    input_dim_check = tf.assert_equal(
        tf.shape(Xmu)[1],
        tf.constant(self.input_dim, settings.int_type),
        message="Currently cannot handle slicing in exKxz.",
    )
    shape_check = tf.assert_equal(
        tf.shape(Xmu),
        tf.shape(Xcov)[:2],
        name="assert_Xmu_Xcov_shape",
    )

    # Routing Xmu through an identity op under the control dependencies makes
    # every downstream consumer of Xmu wait on both assertions.
    with tf.control_dependencies([input_dim_check, shape_check]):
        Xmu = tf.identity(Xmu)

    num_points = tf.shape(Xmu)[0]

    # Per-row outer product Xmu Xmu^T plus Xcov.  # NxDxD
    mean_outer = tf.expand_dims(Xmu, 2) * tf.expand_dims(Xmu, 1)
    second_moment = mean_outer + Xcov

    variance = self.variance

    # Repeat Z along a new leading axis so matmul runs once per row of Xmu.
    Z_batched = tf.tile(tf.expand_dims(Z, 0), (num_points, 1, 1))
    return variance * tf.matmul(Z_batched, second_moment)
评论列表
文章目录