memory.py source code

python

Project: tefla    Author: openAGI
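import tensorflow as tf  # TensorFlow 1.x API; under TF 2, use tf.compat.v1


# Method of the memory class in memory.py; relies on self.mem_age,
# self.mem_keys, self.mem_vals, self.recent_idx and self.memory_size.
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Create one grouped op that writes a batch of entries into memory."""
    # Age every memory slot by one step.
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    # Only after aging, reset the age of the slots being written to zero
    # (upd_idxs is expected to hold batch_size indices).
    with tf.control_dependencies([mem_age_incr]):
        mem_age_upd = tf.scatter_update(
            self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    # Overwrite the keys and values stored at the selected slots.
    mem_key_upd = tf.scatter_update(
        self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(
        self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
        # For each intended output, remember the slot that was just written.
        recent_idx_upd = tf.scatter_update(
            self.recent_idx, intended_output, upd_idxs)
    else:
        recent_idx_upd = tf.group()  # no-op placeholder

    # Bundle all updates so the caller can run them as a single op.
    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)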
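To make the semantics of the grouped op concrete, here is a minimal NumPy sketch that performs the same four updates eagerly on plain arrays. It is not the tefla implementation: the helper name `make_update`, the in-place style, and the array shapes (`mem_age` and `mem_vals` of length `memory_size`, `mem_keys` of shape `[memory_size, key_dim]`, `recent_idx` of length `vocab_size`) are assumptions made for illustration.

```python
import numpy as np


def make_update(mem_age, mem_keys, mem_vals, recent_idx,
                upd_idxs, upd_keys, upd_vals,
                use_recent_idx, intended_output):
    """Hypothetical NumPy analogue of make_update_op; mutates arrays in place."""
    # 1. Age every slot by one (the assign_add step).
    mem_age += 1.0
    # 2. Then reset the age of the written slots (the scatter_update that
    #    runs under the control dependency on step 1).
    mem_age[upd_idxs] = 0.0
    # 3. Overwrite keys and values at those slots.
    mem_keys[upd_idxs] = upd_keys
    mem_vals[upd_idxs] = upd_vals
    # 4. Optionally record, per intended output, the slot just written.
    if use_recent_idx:
        recent_idx[intended_output] = upd_idxs


# Assumed toy sizes: 4 memory slots, 2-dim keys, 3 possible outputs.
memory_size, key_dim, vocab_size = 4, 2, 3
mem_age = np.array([3.0, 0.0, 5.0, 1.0], dtype=np.float32)
mem_keys = np.zeros((memory_size, key_dim), dtype=np.float32)
mem_vals = np.zeros(memory_size, dtype=np.int32)
recent_idx = np.zeros(vocab_size, dtype=np.int32)

make_update(mem_age, mem_keys, mem_vals, recent_idx,
            upd_idxs=np.array([2, 0]),
            upd_keys=np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32),
            upd_vals=np.array([7, 9]),
            use_recent_idx=True,
            intended_output=np.array([1, 2]))
print(mem_age)  # [0. 1. 0. 2.]
```

The order matters: aging first and resetting second leaves freshly written slots at age 0 while every untouched slot grows one step older, which is what the control dependency enforces in the graph version. TensorFlow documents the order of duplicate writes in scatter_update as undefined, so this sketch says nothing about which write wins when upd_idxs repeats a slot.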