def normalize_to_unit_sum(x, EPS=1e-10):
    """Rescale ``x`` so that each slice along the last dimension sums to 1.

    ``EPS`` is added to every element before normalizing. This keeps the
    denominator away from zero when a slice is all zeros. NOTE(review): this
    presumably assumes non-negative inputs (weights/counts); with mixed signs
    a slice can still sum to ~0. Confirm with callers.

    Args:
        x: Tensor-like input. Floating-point inputs keep their dtype.
            Non-floating inputs (e.g. integer lists or tensors) are cast to
            float32 first.
        EPS: Smoothing constant added to every element before normalizing.

    Returns:
        A tensor with the same shape as ``x`` whose slices along the last
        axis sum to (approximately) 1.
    """
    x = tf.convert_to_tensor(x)
    if not x.dtype.is_floating:
        # Normalize non-float input in float32, as the original float32 EPS
        # constant effectively did. Casting EPS to an int dtype would round
        # it to 0.
        x = tf.cast(x, tf.float32)
    # Match EPS to x's dtype. A hard-coded float32 constant makes
    # float16/float64 inputs fail with a dtype mismatch on the addition.
    x = x + tf.cast(EPS, x.dtype)
    # `keepdims` (the TF1-only `keep_dims` spelling was removed in TF2) keeps
    # the reduced axis so the sum broadcasts back against x.
    x_sum = tf.reduce_sum(x, axis=-1, keepdims=True)
    return tf.divide(x, x_sum)