def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Create the grouped op that writes a batch of entries into memory.

    Running the returned op performs these writes:
      * every entry of ``self.mem_age`` is incremented by 1, and only after
        that the ages at ``upd_idxs`` are reset to 0;
      * ``upd_keys`` are scattered into ``self.mem_keys`` at ``upd_idxs``;
      * ``upd_vals`` are scattered into ``self.mem_vals`` at ``upd_idxs``;
      * if ``use_recent_idx`` is truthy, ``upd_idxs`` are scattered into
        ``self.recent_idx`` at positions ``intended_output``.

    Args:
      upd_idxs: Indices of the memory slots being written.
      upd_keys: Values written into ``self.mem_keys`` at ``upd_idxs``.
      upd_vals: Values written into ``self.mem_vals`` at ``upd_idxs``.
      batch_size: Length of the zero vector used to reset ages; presumably
        equal to the number of entries in ``upd_idxs`` (TODO confirm with
        callers).
      use_recent_idx: Whether to also update ``self.recent_idx``.
      intended_output: Indices into ``self.recent_idx``; only read when
        ``use_recent_idx`` is truthy.

    Returns:
      A ``tf.group`` op bundling all of the updates above.
    """
    # Age every memory slot by one step.
    all_ones = tf.ones([self.memory_size], dtype=tf.float32)
    age_tick = self.mem_age.assign_add(all_ones)

    # Both ops write mem_age, so the reset must be explicitly ordered after
    # the increment; otherwise the freshly written slots could end up with
    # age 1 instead of 0.
    # NOTE(review): per the tf scatter_update docs, duplicate entries in
    # upd_idxs make the winning write order-undefined — confirm callers
    # avoid duplicates.
    with tf.control_dependencies([age_tick]):
        age_reset = tf.scatter_update(
            self.mem_age, upd_idxs,
            tf.zeros([batch_size], dtype=tf.float32))

    # Key/value writes target other variables and carry no ordering
    # constraint relative to the age updates.
    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty tf.group() stands in as a no-op so the final group always
    # receives the same four ops.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)
评论列表
文章目录