def checkAverager(self):
    """Check ``utils.averager`` against a known mean, with and without ``Variable``.

    Two inputs are added to a fresh averager: the 1-D tensor ``[1, 2]`` and
    the 2-D tensor ``[[5, 6]]``. The asserted result 3.5 equals the mean of
    all four scalars, (1 + 2 + 5 + 6) / 4. The same scenario runs twice:
    first with each tensor wrapped in ``Variable``, then with plain tensors.

    NOTE(review): these inputs also give 3.5 if the averager instead averaged
    the per-call means ((1.5 + 5.5) / 2). So this test cannot tell
    element-weighted averaging from call-weighted averaging. Consider
    unequal-sized inputs if element weighting is the intended contract.
    """
    # Each wrapper is applied to every input before it is handed to add().
    # The identity wrapper reproduces the plain-tensor case.
    input_wrappers = (Variable, lambda tensor: tensor)
    for wrap in input_wrappers:
        running = utils.averager()
        running.add(wrap(torch.Tensor([1, 2])))
        running.add(wrap(torch.Tensor([[5, 6]])))
        assert running.val() == 3.5
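# Scraped page residue (not code): "评论列表" = "Comment list"
# Scraped page residue (not code): "文章目录" = "Table of contents"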