def _argmax(self, input_tensor, dimension, c):
    """
    A constrainable version of tf.argmax.

    Along ``dimension``, returns the index of the largest element of
    ``input_tensor`` among the positions where ``c`` is 1. Positions where
    ``c`` is 0 are never selected as long as the slice contains at least one
    allowed position.

    Parameters:
    -----------
    input_tensor: Tensor
        The values to take the argmax of.
    dimension: Tensor
        The axis along which the argmax is taken.
    c: Tensor
        The constraints tensor.
        A tensor of 0s and 1s where 1s represent the elements the reduction
        should be made on, and 0s represent discarded elements. It is
        multiplied elementwise with ``input_tensor``, so it must have the
        same dtype and a shape broadcast-compatible with ``input_tensor``.

    Returns:
    --------
    Tensor
        The indices produced by ``tf.argmax`` over ``dimension``. For a slice
        with no allowed position, the returned index is not meaningful.
    """
    with self.session.graph.as_default():
        not_c = tf.abs(c - 1)
        min_values = tf.reduce_min(input_tensor, reduction_indices=[dimension],
                                   keep_dims=True)
        # Discarded positions are overwritten with the slice minimum, so the
        # slice maximum of `masked` equals the maximum over allowed positions.
        # NOTE(review): a discarded position holding +/-inf becomes NaN here
        # (0 * inf); inputs are presumably finite - confirm with callers.
        masked = input_tensor * c + not_c * min_values
        best = tf.reduce_max(masked, reduction_indices=[dimension],
                             keep_dims=True)
        # Taking tf.argmax of `masked` directly is wrong when that maximum
        # equals the slice minimum: values [5, 1, 3] with c = [0, 1, 0] give
        # masked = [1, 1, 1], and the discarded index 0 may be returned.
        # Instead, flag only the allowed positions that attain the maximum and
        # take the argmax of that 0/1 indicator.
        hits = tf.cast(tf.equal(masked, best), masked.dtype) * c
        return tf.argmax(hits, dimension)
评论列表
文章目录