def update(self, query, y, y_hat, y_hat_indices):
    """Write a batch of labelled queries back into the memory.

    Applies the memory-update rules to every example in the batch:

    1. Every slot's age ``A`` is incremented by 1.
    2. Correct examples (retrieved value ``y_hat[i]`` equals label ``y[i]``):
       with ``n1 = y_hat_indices[i]``, update ``K[n1] <- normalize(q + K[n1])``
       and reset ``A[n1] <- 0``.
    3. Incorrect examples: pick the oldest slots
       ``n' = argmax_i(A[i] + r_i)`` (``r_i`` uniform in
       ``[-age_noise, age_noise)``) and set ``K[n'] <- q``, ``V[n'] <- y``,
       ``A[n'] <- 0``.

    Args:
        query: 2-D tensor of query vectors, one row per example.
        y: labels, one per example.
        y_hat: values retrieved for each example; assumed to hold exactly one
            value per example, e.g. shape ``(batch,)`` or ``(batch, 1)`` --
            TODO confirm against the caller.
        y_hat_indices: 1-D tensor of the memory slot each ``y_hat`` came from.

    Raises:
        ValueError: if ``query`` is not 2-D.
    """
    if query.dim() != 2:
        raise ValueError(
            "query must be 2-D (batch, dims), got shape {}".format(tuple(query.shape))
        )
    query = query.detach()
    y = y.detach()

    # 1) Untouched: every slot gets one step older.
    self.age += 1

    # Split the batch by correctness. Flatten to 1-D so that a single hit
    # still yields a 1-element index tensor (not a 0-d scalar) and an empty
    # split yields a 0-length tensor; emptiness is then tested via numel().
    is_correct = torch.eq(y_hat.reshape(-1), y.reshape(-1))
    correct_examples = torch.nonzero(is_correct).view(-1)
    incorrect_examples = torch.nonzero(~is_correct).view(-1)

    # 2) Correct: V[n1] == v
    #    K[n1] <- normalize(q + K[n1]), A[n1] <- 0
    if correct_examples.numel() > 0:
        correct_indices = y_hat_indices[correct_examples]
        correct_keys = self.keys[correct_indices]
        correct_query = query[correct_examples]
        self.keys[correct_indices] = F.normalize(correct_keys + correct_query, dim=1)
        self.age[correct_indices] = 0

    # 3) Incorrect: V[n1] != v
    #    n' = argmax_i(A[i] + r_i); K[n'] <- q, V[n'] <- v, A[n'] <- 0
    if incorrect_examples.numel() > 0:
        incorrect_size = incorrect_examples.numel()
        incorrect_query = query[incorrect_examples]
        incorrect_values = y[incorrect_examples]
        # Random offsets r_i in [-age_noise, age_noise), drawn on the age
        # tensor's own shape and device rather than forcing CUDA.
        noise = (torch.rand(self.age.shape, device=self.age.device) * 2 - 1) * self.age_noise
        _, oldest_indices = torch.topk(self.age + noise, incorrect_size, dim=0)
        oldest_indices = oldest_indices.view(-1)
        self.keys[oldest_indices] = incorrect_query
        # Match the value store's trailing shape so per-example labels of
        # shape (k,) can be written into e.g. a (memory_size, 1) tensor.
        self.values[oldest_indices] = incorrect_values.reshape(-1, *self.values.shape[1:])
        self.age[oldest_indices] = 0
# Web-page residue from where this snippet was copied:
# "评论列表" (comment list), "文章目录" (table of contents).