def exKxz_pairwise(self, Z, Xmu, Xcov):
    """
    Build ``variance * Z @ (Xmu[n] Xmu[n+1]^T + Xcov[1, n])`` for every
    consecutive pair of rows ``(n, n+1)`` of ``Xmu``.

    :param Z: matrix that is tiled along a new leading axis and
        left-multiplied onto each per-pair DxD matrix
        (presumably MxD inducing inputs -- confirm against callers).
    :param Xmu: rank-2 tensor; its second dimension is asserted to equal
        ``self.input_dim``. With N+1 rows there are N consecutive pairs.
    :param Xcov: rank-4 tensor whose dimensions 1 and 2 are asserted to
        match the shape of ``Xmu``. Only ``Xcov[1, :-1]`` is read here
        (presumably the cross-covariances between consecutive rows --
        confirm against callers).
    :return: ``self.variance`` times the batched matrix product of the
        tiled ``Z`` with the per-pair matrices (presumably NxMxD).
    """
    # Graph-mode shape checks. They only execute because the identity op
    # below is created under a control dependency on them.
    input_dim_check = tf.assert_equal(
        tf.shape(Xmu)[1],
        tf.constant(self.input_dim, settings.tf_int),
        message="Currently cannot handle slicing in exKxz.")
    shape_match_check = tf.assert_equal(
        tf.shape(Xmu), tf.shape(Xcov)[1:3], name="assert_Xmu_Xcov_shape")
    with tf.control_dependencies([input_dim_check, shape_match_check]):
        Xmu = tf.identity(Xmu)

    # One fewer pair than there are rows.
    num_pairs = tf.shape(Xmu)[0] - 1
    mu_prev = Xmu[:-1, :]  # rows 0 .. N-1
    mu_next = Xmu[1:, :]   # rows 1 .. N

    # Per-pair outer product mu_prev[n] mu_next[n]^T, plus Xcov[1, n].
    outer = tf.expand_dims(mu_prev, 2) * tf.expand_dims(mu_next, 1)
    second_moment = outer + Xcov[1, :-1, :, :]  # NxDxD

    # Replicate Z once per pair so matmul runs as a batch over the pairs.
    Z_batched = tf.tile(tf.expand_dims(Z, 0), (num_pairs, 1, 1))
    return self.variance * tf.matmul(Z_batched, second_moment)
评论列表
文章目录