def _slice_cov(self, cov):
    """
    Restrict covariance matrices to the dimensions the kernel acts on.

    The selection given by `self.active_dims` is applied to both the rows
    and the columns of every matrix. A rank-2 input is interpreted as
    flattened diagonals and is first expanded into full diagonal matrices.

    :param cov: Tensor of covariance matrices (NxDxD or NxD).
    :return: N x self.input_dim x self.input_dim.
    """
    # Rank is checked inside the graph, so the diagonal expansion is
    # chosen when the tensor is evaluated rather than in Python.
    is_flat_diag = tf.equal(tf.rank(cov), 2)
    full_cov = tf.cond(is_flat_diag, lambda: tf.matrix_diag(cov), lambda: cov)

    dims = self.active_dims
    if isinstance(dims, slice):
        # A slice can be applied directly to the two trailing axes.
        return full_cov[..., dims, dims]

    # Index-list case: tf.gather picks along the leading axis here, so each
    # axis to be filtered is rotated to the front in turn.
    n_active = len(dims)
    orig_shape = tf.shape(full_cov)
    size = orig_shape[-1]

    # Collapse all leading axes into one: [batch, row, col].
    batched = tf.reshape(full_cov, [-1, size, size])

    # [batch, row, col] -> [col, row, batch]; keep the active columns.
    by_col = tf.transpose(batched, [2, 1, 0])
    cols_kept = tf.gather(by_col, dims)

    # [col, row, batch] -> [row, col, batch]; keep the active rows.
    by_row = tf.transpose(cols_kept, [1, 0, 2])
    rows_kept = tf.gather(by_row, dims)

    # [row, col, batch] -> [batch, row, col], then restore the leading axes.
    restored = tf.transpose(rows_kept, [2, 0, 1])
    out_shape = tf.concat([orig_shape[:-2], [n_active, n_active]], 0)
    return tf.reshape(restored, out_shape)
# Comment list
# Table of contents