def _check_rank_type(self, X, M):
"""Check the rank of the input tensors."""
data_rank = len(X.shape)
mask_rank = len(M.shape)
assert data_rank == 3
assert mask_rank == 2
assert tf.as_dtype(M.dtype).is_bool