def add(self, keys, values, k=5):
    """Insert a batch of (key, value) pairs into the memory.

    Before inserting, the weight of each of the ``k`` stored entries nearest
    to every incoming key is multiplied by ``(1 - self.alpha)``.

    Each key is then written to the next free slot. Once ``self.capacity``
    slots are in use, it overwrites the slot with the smallest weight
    instead. Every written slot is given weight 1.0.

    Finally ``self.tree`` is rebuilt over the occupied slots, with each
    entry's weight appended to its state as an extra coordinate.

    Args:
        keys: 2-D array-like of state vectors, one row per key. It must be
            2-D because ``np.pad`` below pads its second axis.
        values: Indexable sequence aligned with ``keys``; ``values[i]`` is
            stored alongside ``keys[i]``.
        k: Number of nearest stored entries decayed per incoming key. The
            decay step only runs once at least ``k`` entries are stored, so
            the tree query always has enough points. Defaults to 5, the
            previously hard-coded value.
    """
    if self.curr_capacity >= k:
        # The tree is built over [state, weight] rows, so queries get an extra
        # weight coordinate of 1.0 to match its dimensionality.
        # NOTE(review): assumes self.tree is the tree rebuilt at the end of the
        # previous call (the only assignment visible here) -- confirm.
        padded_keys = np.pad(keys, ((0, 0), (0, 1)), 'constant', constant_values=1.0)
        _, neighbour_rows = self.tree.query(padded_keys, k=k)
        decay = 1 - self.alpha
        # Decay row by row: an entry that is near several incoming keys is
        # decayed once per key.
        for neighbours in neighbour_rows:
            self.weights[neighbours] = self.weights[neighbours] * decay

    for i, key in enumerate(keys):
        if self.curr_capacity >= self.capacity:
            # Memory is full: overwrite the entry with the smallest weight.
            index = np.argmin(self.weights)
        else:
            index = self.curr_capacity
            self.curr_capacity += 1
        self.states[index] = key
        self.q_values[index] = values[i]
        self.weights[index] = 1.0

    # Rebuild the search tree over occupied slots only, using [state, weight]
    # rows so that later queries take an entry's weight into account.
    occupied = self.curr_capacity
    self.tree = KDTree(np.concatenate(
        (self.states[:occupied], np.expand_dims(self.weights[:occupied], axis=1)),
        axis=1,
    ))
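# EC_functions.py -- file source code (python).
# Footer residue from the web page this code was copied from:
# 25 reads, 0 favorites, 0 likes, 0 comments; navigation labels "comment list"
# and "table of contents".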