def compute_mask(self, inputs, mask=None):
    """Return a boolean mask of entries that are not pure ``mask_value`` padding.

    First, every position of ``inputs`` (all axes except the last) is marked
    as present when any of its last-axis values differs from
    ``self.mask_value``. That presence map is then reduced with ``K.any``
    (``keepdims=True``) once per axis from 1 through ``ndim(inputs) - 2``.
    The reductions are combined with ``tf.logical_and``. Axis 0 is never
    reduced (presumably the batch axis -- TODO confirm).

    For a rank-3 input ``(d0, d1, d2)`` the result has shape ``(d0, 1)``:
    one flag per axis-0 entry, True if any position along axis 1 is present.
    For rank ``r >= 4``, the singleton axes of the individual reductions
    broadcast against one another, giving shape ``(d0, d1, ..., d_{r-2})``.
    An entry is then True only when, for every axis 1..r-2, the line through
    that entry along that axis contains at least one present position.

    NOTE(review): a rank-2 input leaves a rank-1 presence map, so the axis-1
    reduction is out of range and presumably raises -- confirm whether such
    inputs can reach this layer.

    Args:
        inputs: Tensor whose last axis is treated as the feature axis.
        mask: Incoming mask; not used.

    Returns:
        Boolean tensor shaped as described above.
    """
    rank = K.ndim(inputs)
    # True wherever at least one last-axis value differs from mask_value.
    present = K.any(K.not_equal(inputs, self.mask_value), axis=-1)
    combined = K.any(present, axis=1, keepdims=True)
    # Fold in the remaining middle axes. keepdims=True keeps each reduction
    # broadcast-compatible with `combined`.
    for reduce_axis in range(2, rank - 1):
        combined = tf.logical_and(
            combined, K.any(present, axis=reduce_axis, keepdims=True)
        )
    return combined
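# Web-page residue captured with this snippet (not code):
# "评论列表" ("Comment list"), "文章目录" ("Table of contents")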