topn.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def get_best(self, n):
  """Return the indices and values of the n highest scores in the TopN.

  The shortlist variables reserve slot 0 for bookkeeping. refresh_shortlist
  writes the shortlist length into sl_ids[0] and the smallest refreshed score
  into sl_scores[0]. The (id, score) payload lives in slots 1.., so slot 0 is
  excluded from the top-k search.

  Args:
    n: Number of entries requested. It is compared with sl_ids[0] and passed
      to tf.minimum together with an int32, so it is presumably a Python int
      or an int32 scalar tensor -- TODO confirm against callers.

  Returns:
    A pair (ids, scores). ids are entries of sl_ids, i.e. indices into
    id_to_score. scores are the matching sl_scores values, sorted from the
    highest score down. At most min(n, sl_ids[0]) entries are returned.
  """

  def refresh_shortlist():
    """Update the shortlist with the highest scores in id_to_score."""
    new_scores, new_ids = tf.nn.top_k(self.id_to_score, self.shortlist_size)
    smallest_new_score = tf.reduce_min(new_scores)
    # Count only scores strictly above tf.float32.min; those are treated as
    # occupied shortlist entries.
    new_length = tf.reduce_sum(
        tf.to_int32(tf.greater(new_scores, tf.float32.min)))
    # Slot 0 is a header: the length for ids, the minimum score for scores.
    u1 = self.sl_ids.assign(
        tf.to_int64(tf.concat(0, [[new_length], new_ids])))
    u2 = self.sl_scores.assign(
        tf.concat(0, [[smallest_new_score], new_scores]))
    # Subsequent get_best calls take these as control dependencies, so they
    # read the shortlist only after these assignments.
    self.last_ops = [u1, u2]
    return tf.group(u1, u2)

  # We only need to refresh the shortlist if n is greater than the
  # current shortlist size (which is stored in sl_ids[0]).
  with tf.control_dependencies(self.last_ops):
    cond_op = tf.cond(n > self.sl_ids[0], refresh_shortlist, tf.no_op)
    with tf.control_dependencies([cond_op]):
      # Search only the payload slots. sl_scores[0] equals the smallest real
      # score after a refresh. On a tie, tf.nn.top_k prefers the lower index,
      # so slot 0 would win, and gathering from sl_ids would return the
      # shortlist length as though it were an id.
      payload_scores = self.sl_scores[1:]
      payload_ids = self.sl_ids[1:]
      topk_values, topk_indices = tf.nn.top_k(
          payload_scores, tf.minimum(n, tf.to_int32(self.sl_ids[0])))
      # topk_indices index into the payload slots; map them back to the
      # corresponding indices into id_to_score.
      gathered_indices = tf.gather(payload_ids, topk_indices)
      return gathered_indices, topk_values
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号