def build(self):
    """Allocate the memory tensors for this module on the CUDA device.

    Reads ``self.memory_size`` and ``self.key_dim`` and (over)writes four
    attributes:

    * ``keys``     -- ``(memory_size, key_dim)`` tensor produced by the
      ``random_uniform`` helper with bounds ``-0.001`` / ``0.001``, then
      L2-normalised along dim 1 via ``F.normalize`` (default ``p=2``).
    * ``keys_var`` -- ``ag.Variable`` wrapper around ``keys`` with
      ``requires_grad=False``.
    * ``values``   -- ``(memory_size, 1)`` zero tensor cast to ``long`` and
      moved to CUDA.
    * ``age``      -- ``(memory_size, 1)`` zero tensor in the default
      floating dtype, moved to CUDA.

    NOTE(review): the device is hard-coded (``cuda=True`` / ``.cuda()``),
    so calling this requires a CUDA-capable device.
    """
    key_shape = (self.memory_size, self.key_dim)
    # Tiny initial spread; presumably uniform within [-0.001, 0.001] going by
    # the helper's name -- TODO confirm against random_uniform.
    raw_keys = random_uniform(key_shape, -0.001, 0.001, cuda=True)
    self.keys = F.normalize(raw_keys, dim=1)
    self.keys_var = ag.Variable(self.keys, requires_grad=False)

    column_shape = (self.memory_size, 1)
    # values is cast to long before the move; age keeps the default float dtype.
    self.values = torch.zeros(*column_shape).long().cuda()
    self.age = torch.zeros(*column_shape).cuda()
评论列表
文章目录