def _slice(self, X, X2):
    """
    Slice the correct dimensions for use in the kernel, as indicated by
    `self.active_dims`.

    Columns are selected along the last axis, so for the usual rank-2
    inputs this is plain column selection; any leading axes are passed
    through unchanged.

    :param X: Input 1 (NxD).
    :param X2: Input 2 (MxD), may be None.
    :return: Sliced X, X2, (Nxself.input_dim) and (Mxself.input_dim).
        Each returned tensor (X2 only when it is not None) carries a runtime
        assertion that its last dimension equals `self.input_dim`.
    """
    if isinstance(self.active_dims, slice):
        X = X[..., self.active_dims]
        if X2 is not None:
            X2 = X2[..., self.active_dims]
    else:
        # Gather directly along the feature axis rather than transposing the
        # whole tensor twice around a row-wise gather.
        X = tf.gather(X, self.active_dims, axis=-1)
        if X2 is not None:
            X2 = tf.gather(X2, self.active_dims, axis=-1)

    input_dim = tf.convert_to_tensor(self.input_dim, dtype=settings.tf_int)

    # The assertions are attached through control dependencies on an identity
    # op, so they are evaluated whenever the returned tensors are evaluated.
    with tf.control_dependencies([tf.assert_equal(tf.shape(X)[-1], input_dim)]):
        X = tf.identity(X)
    if X2 is not None:
        # X2 must also end up with `input_dim` columns; previously only X was checked.
        with tf.control_dependencies([tf.assert_equal(tf.shape(X2)[-1], input_dim)]):
            X2 = tf.identity(X2)
    return X, X2
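# 评论列表 ("comment list") -- stray web-page text, not part of the code
# 文章目录 ("table of contents") -- stray web-page text, not part of the code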