topn.py source code

python

Project: lsdc  Author: febert
def remove(self, ids):
    """Remove the ids (and their associated scores) from the TopN."""
    with tf.control_dependencies(self.last_ops):
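      # Overwrite the scores of the removed ids with the smallest float32
      # value so they can no longer rank among the top entries.
      scatter_op = tf.scatter_update(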
          self.id_to_score,
          ids,
          tf.ones_like(
              ids, dtype=tf.float32) * tf.float32.min)
      # We assume that removed ids are almost always in the shortlist,
      # so it makes no sense to hide the Op behind a tf.cond
      shortlist_ids_to_remove, new_length = self.ops.top_n_remove(self.sl_ids,
                                                                  ids)
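      # Slot 0 of sl_ids receives the new shortlist length; every removed
      # slot is reset to -1. From TensorFlow 1.0 onward tf.concat takes
      # (values, axis), so the list comes first and the axis second.
      u1 = tf.scatter_update(
          self.sl_ids, tf.concat([[0], shortlist_ids_to_remove], 0),
          tf.concat([new_length,
                     tf.ones_like(shortlist_ids_to_remove) * -1], 0))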
      u2 = tf.scatter_update(
          self.sl_scores,
          shortlist_ids_to_remove,
          tf.float32.min * tf.ones_like(
              shortlist_ids_to_remove, dtype=tf.float32))
      self.last_ops = [scatter_op, u1, u2]
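
The three updates in remove() can be hard to follow without a running graph, so here is a minimal NumPy sketch of the same bookkeeping. It is not the project's implementation. The function name `remove_from_topn` and the sample arrays are hypothetical, and the `np.isin` lookup is only a stand-in for the custom `top_n_remove` op, whose exact behavior is not shown in the source. The sketch assumes that `sl_ids[0]` stores the shortlist length (inferred from the scatter into index 0 above), that -1 marks an empty slot, and that all ids are non-negative.

```python
import numpy as np

# Same value as tf.float32.min: the most negative finite float32.
FLOAT32_MIN = np.finfo(np.float32).min


def remove_from_topn(id_to_score, sl_ids, sl_scores, ids):
    """Apply the three in-place updates of remove() to NumPy arrays.

    Assumed layout: sl_ids[0] is the shortlist length, sl_ids[1:] are ids,
    and -1 marks an empty slot. Ids are assumed to be non-negative.
    """
    ids = np.asarray(ids, dtype=np.int64)

    # Mirrors scatter_op: removed ids get the lowest possible score.
    id_to_score[ids] = FLOAT32_MIN

    # Stand-in for top_n_remove: shortlist slots that hold a removed id.
    slots = np.arange(1, len(sl_ids))
    hit = slots[np.isin(sl_ids[1:], ids)]
    new_length = sl_ids[0] - len(hit)

    # Mirrors u1: clear the removed slots and store the new length in slot 0.
    sl_ids[hit] = -1
    sl_ids[0] = new_length

    # Mirrors u2: reset the scores of the removed shortlist slots.
    sl_scores[hit] = FLOAT32_MIN
    return hit, new_length


# Hypothetical state: four ids overall, shortlist currently holds ids 1 and 3.
id_to_score = np.array([0.1, 0.9, 0.5, 0.7], dtype=np.float32)
sl_ids = np.array([2, 1, 3, -1], dtype=np.int64)
# Slot 0 of sl_scores is not used by this sketch; 0.0 is a placeholder.
sl_scores = np.array([0.0, 0.9, 0.7, FLOAT32_MIN], dtype=np.float32)

hit, new_length = remove_from_topn(id_to_score, sl_ids, sl_scores, [3])
print(sl_ids.tolist(), hit.tolist(), int(new_length))
```

In this sketch, removing id 3 clears shortlist slot 2 and lowers the stored length from 2 to 1, while id 1 and its score stay in place. Unlike the TensorFlow version, nothing here is ordered through control dependencies; the updates simply run in sequence.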