def get_best(self, n):
    """Return the indices and values of the n highest scores in the TopN.

    Slot 0 of the shortlist variables holds bookkeeping rather than an
    entry:
      * ``sl_ids[0]`` holds the number of valid shortlist entries.
      * ``sl_scores[0]`` holds the smallest shortlisted score.
    The real (id, score) pairs live in slots 1 and up.

    If ``n`` exceeds the current shortlist length, the shortlist is first
    rebuilt from the top ``shortlist_size`` scores in ``id_to_score``.

    Args:
      n: Number of results requested. Presumably a Python int or an int32
        scalar tensor, since it is compared against the int32-cast
        shortlist length and combined with it via ``tf.minimum``. Confirm
        this with callers.

    Returns:
      A ``(indices, values)`` pair of 1-D tensors with at most
      ``min(n, shortlist length)`` entries:
        * ``indices`` are positions in ``id_to_score``.
        * ``values`` are the matching scores, ordered from highest to
          lowest.
    """

    def refresh_shortlist():
        """Rebuild the shortlist from the highest scores in id_to_score."""
        new_scores, new_ids = tf.nn.top_k(self.id_to_score, self.shortlist_size)
        # Slot-0 bookkeeping values: the lowest score kept, and the number
        # of kept scores that are strictly above tf.float32.min.
        smallest_new_score = tf.reduce_min(new_scores)
        new_length = tf.reduce_sum(
            tf.to_int32(tf.greater(new_scores, tf.float32.min)))
        u1 = self.sl_ids.assign(
            tf.to_int64(tf.concat(0, [[new_length], new_ids])))
        u2 = self.sl_scores.assign(
            tf.concat(0, [[smallest_new_score], new_scores]))
        # NOTE(review): u1/u2 are created inside the tf.cond true branch.
        # Confirm that later control dependencies on self.last_ops behave as
        # intended when that branch is not taken.
        self.last_ops = [u1, u2]
        return tf.group(u1, u2)

    # Only refresh when more results are requested than the shortlist
    # currently holds (its length is stored in sl_ids[0]). The length is
    # cast to int32 to match the cast used for the top_k size below, so an
    # int32 `n` works in both places.
    with tf.control_dependencies(self.last_ops):
        cond_op = tf.cond(n > tf.to_int32(self.sl_ids[0]),
                          refresh_shortlist, tf.no_op)
        with tf.control_dependencies([cond_op]):
            # The length is read again here so that it reflects a refresh.
            # NOTE(review): confirm these variable reads are really ordered
            # after cond_op in the TF version in use.
            k = tf.minimum(n, tf.to_int32(self.sl_ids[0]))
            # Search slots 1.. only. Slot 0 is bookkeeping, and its score
            # ties the smallest real score (always true right after a
            # refresh of a full shortlist). top_k favours the lower index on
            # ties, so searching the full vector could return the stored
            # length as if it were an id.
            topk_values, topk_indices = tf.nn.top_k(self.sl_scores[1:], k)
            # topk_indices are positions among the shortlist entries
            # (slots 1..). Map them back to indices into id_to_score.
            gathered_indices = tf.gather(self.sl_ids[1:], topk_indices)
            return gathered_indices, topk_values
评论列表
文章目录