def remove(self, ids):
    """Drop the given ids, together with their scores, from the TopN.

    Builds graph ops that overwrite the removed entries with sentinel
    values. The new ops are sequenced after `self.last_ops`, and
    `self.last_ops` is then replaced by them. Nothing is returned.

    Args:
        ids: ids to remove; used as scatter indices into
            `self.id_to_score`.
    """
    with tf.control_dependencies(self.last_ops):
        # Every removed id gets the lowest representable float32 score.
        lowest_scores = tf.ones_like(ids, dtype=tf.float32) * tf.float32.min
        scatter_op = tf.scatter_update(self.id_to_score, ids, lowest_scores)

        # Removed ids are assumed to sit in the shortlist nearly every time,
        # so the shortlist op runs unconditionally instead of under tf.cond.
        sl_positions, updated_length = self.ops.top_n_remove(self.sl_ids, ids)

        # Index 0 of sl_ids receives the new length; each removed position
        # is reset to -1.
        # NOTE(review): tf.concat(0, values) is the axis-first argument order
        # of pre-1.0 TensorFlow -- confirm the targeted version before
        # touching it.
        id_indices = tf.concat(0, [[0], sl_positions])
        id_values = tf.concat(
            0, [updated_length, tf.ones_like(sl_positions) * -1])
        ids_update = tf.scatter_update(self.sl_ids, id_indices, id_values)

        # The matching shortlist scores drop to the float32 minimum as well.
        cleared_scores = tf.float32.min * tf.ones_like(
            sl_positions, dtype=tf.float32)
        scores_update = tf.scatter_update(
            self.sl_scores, sl_positions, cleared_scores)

        self.last_ops = [scatter_op, ids_update, scores_update]
评论列表
文章目录