EC_functions.py source code

python

Project: nn_q_learning_tensorflow, Author: EndingCredits
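import numpy as np
from sklearn.neighbors import KDTree  # assumed import; the original file's imports are not shown

def add(self, keys, values):
    """Insert a batch of (key, value) pairs into the episodic memory."""
    # Decay the weights of the 5 stored entries nearest to each new key.
    # Tree points are [state, weight], so each query key is padded with a
    # weight coordinate of 1.0 (the weight a freshly written entry gets).
    if self.curr_capacity >= 5:
        padded_keys = np.pad(keys, ((0, 0), (0, 1)), 'constant', constant_values=1.0)
        _, ind = self.tree.query(padded_keys, k=5)
        decay = 1 - self.alpha
        for neighbours in ind:
            self.weights[neighbours] = self.weights[neighbours] * decay

    for i in range(len(keys)):
        if self.curr_capacity >= self.capacity:
            # Memory is full: overwrite the lowest-weight entry, i.e. the one
            # that has been a near neighbour of the most new keys since it was
            # last written (the original comment calls this the "LRU entry").
            index = np.argmin(self.weights)
        else:
            index = self.curr_capacity
            self.curr_capacity += 1

        self.states[index] = keys[i]
        self.q_values[index] = values[i]
        self.weights[index] = 1.0

    # Rebuild the tree over the filled slots, appending each entry's weight
    # as an extra coordinate so decayed entries sit farther from new queries.
    self.tree = KDTree(np.concatenate(
        (self.states[:self.curr_capacity],
         np.expand_dims(self.weights[:self.curr_capacity], axis=1)),
        axis=1))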
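Because `add` is a method lifted out of its class, it cannot run on its own. The following is a minimal, self-contained sketch of the same insertion logic, assuming a hypothetical `EpisodicMemorySketch` class whose constructor, attribute layout and parameters (`capacity`, `key_dim`, `alpha`, `k`) are illustrative rather than taken from the project. It swaps `KDTree` for a brute-force NumPy nearest-neighbour search, so it depends only on NumPy.

```python
import numpy as np


class EpisodicMemorySketch:
    """Hypothetical fixed-capacity key-value memory mirroring add() above.

    Brute-force nearest-neighbour search replaces KDTree; capacity, alpha
    and k are illustrative parameters, not values from the original project.
    """

    def __init__(self, capacity, key_dim, alpha=0.1, k=5):
        self.capacity = capacity
        self.alpha = alpha
        self.k = k
        self.curr_capacity = 0
        self.states = np.zeros((capacity, key_dim))
        self.q_values = np.zeros(capacity)
        self.weights = np.zeros(capacity)

    def _points(self):
        # Stored points are [state, weight], as in the KDTree version.
        n = self.curr_capacity
        return np.concatenate((self.states[:n], self.weights[:n, None]), axis=1)

    def _knn(self, queries):
        # Brute-force k nearest neighbours by Euclidean distance.
        diff = queries[:, None, :] - self._points()[None, :, :]
        dist = np.sqrt((diff ** 2).sum(axis=2))
        return np.argsort(dist, axis=1)[:, :self.k]

    def add(self, keys, values):
        keys = np.asarray(keys, dtype=float)
        if self.curr_capacity >= self.k:
            # Pad each key with weight 1.0, then decay its k nearest entries.
            padded = np.pad(keys, ((0, 0), (0, 1)), 'constant', constant_values=1.0)
            for neighbours in self._knn(padded):
                self.weights[neighbours] *= 1 - self.alpha
        for key, value in zip(keys, values):
            if self.curr_capacity >= self.capacity:
                index = int(np.argmin(self.weights))  # evict the lowest weight
            else:
                index = self.curr_capacity
                self.curr_capacity += 1
            self.states[index] = key
            self.q_values[index] = value
            self.weights[index] = 1.0


mem = EpisodicMemorySketch(capacity=3, key_dim=2, alpha=0.5, k=2)
mem.add([[0.0, 0.0], [1.0, 0.0]], [10.0, 20.0])
mem.add([[0.1, 0.0]], [30.0])  # decays slots 0 and 1 to 0.5, fills slot 2
mem.add([[5.0, 5.0]], [40.0])  # memory full: slot 1 fell to 0.25 and is overwritten
print(mem.q_values)  # → [10. 40. 30.]
```

With `alpha=0.5`, an entry's weight after being among the nearest neighbours of n new keys is 0.5^n, so slot 1 (hit twice) drops to 0.25 and is the one evicted. Brute force compares every query with every stored entry, which is only practical for small memories.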