def resetSize(self, *args):
    """Set the target size and recompute the element count of its known dims.

    Args:
        *args: Either a single ``torch.Size`` instance, which is stored as-is,
            or the individual dimension values, which are wrapped in a new
            ``torch.Size``. A dimension may be ``-1`` at most once; every
            other dimension must be non-negative.

    Returns:
        ``self``, so calls can be chained.

    Raises:
        AssertionError: If a negative dimension other than ``-1`` is given,
            or if ``-1`` appears more than once.

    Side effects:
        On success, sets ``self.size`` to the resulting ``torch.Size`` and
        ``self.numElements`` to the product of all non-negative dimensions
        (``1`` when there are none). On failure neither attribute is changed.
    """
    if len(args) == 1 and isinstance(args[0], torch.Size):
        new_size = args[0]
    else:
        new_size = torch.Size(args)

    num_elements = 1
    has_inferred_dim = False
    for dim in new_size:
        if dim >= 0:
            num_elements *= dim
            continue
        # Raised explicitly instead of via `assert` so the validation is not
        # stripped under `python -O`; AssertionError is kept so that existing
        # callers catching it continue to work.
        if dim != -1:
            raise AssertionError(
                "invalid dimension {}: negative sizes other than -1 "
                "are not allowed".format(dim)
            )
        if has_inferred_dim:
            raise AssertionError("only one dimension may be -1")
        # NOTE(review): -1 presumably marks a dimension to be inferred later
        # from the input's element count -- confirm against the caller.
        has_inferred_dim = True

    # Commit only after validation so a rejected size leaves state untouched.
    self.size = new_size
    self.numElements = num_elements
    return self
# [page residue] Comment list
# [page residue] Article table of contents