def insert(self, ids, scores):
    """Record (id, score) pairs and update the shortlist when warranted.

    Every pair is written into ``self.id_to_score``. Pairs whose score is
    greater than ``self.sl_scores[0]`` (the shortlist threshold) are also
    passed to ``self.ops.top_n_insert``; the slots it returns are then
    overwritten in ``self.sl_ids`` and ``self.sl_scores``.

    All ops built here are created under a control dependency on
    ``self.last_ops``, and ``self.last_ops`` is replaced with the ops built
    by this call, so successive calls are chained one after another.

    Args:
      ids: Ids to insert. Used as the scatter indices into
        ``self.id_to_score`` and cast to int64 for the shortlist.
      scores: Scores to store for ``ids``.
        NOTE(review): presumably the same length as ``ids``, since one
        boolean mask is applied to both -- confirm against callers.
    """
    with tf.control_dependencies(self.last_ops):
        # The full id -> score table is always updated, shortlisted or not.
        score_write = tf.scatter_update(self.id_to_score, ids, scores)

        threshold = self.sl_scores[0]
        beats_threshold = tf.greater(scores, threshold)

        def update_shortlist():
            # Keep only the entries that beat the threshold.
            ids_as_int64 = tf.to_int64(ids)
            candidate_ids = tf.boolean_mask(ids_as_int64, beats_threshold)
            candidate_scores = tf.boolean_mask(scores, beats_threshold)
            # top_n_insert returns the shortlist positions to overwrite
            # together with the ids and scores to place there.
            slots, slot_ids, slot_scores = self.ops.top_n_insert(
                self.sl_ids, self.sl_scores, candidate_ids, candidate_scores)
            write_ids = tf.scatter_update(self.sl_ids, slots, slot_ids)
            write_scores = tf.scatter_update(
                self.sl_scores, slots, slot_scores)
            return tf.group(write_ids, write_scores)

        # Skip the shortlist update entirely when no score beats the
        # threshold.
        any_candidate = tf.reduce_any(beats_threshold)
        shortlist_op = tf.cond(any_candidate, update_shortlist, tf.no_op)

        # Only a Python attribute is assigned inside this scope; no graph
        # ops are created in it.
        with tf.control_dependencies([shortlist_op]):
            self.last_ops = [score_write, shortlist_op]
# [web-page residue] Comment list
# [web-page residue] Article table of contents