def backward(self, gradE):
    # Legacy autograd.Function backward for the aggregate operation
    #     E_k = sum_i a_ik * (x_i - c_k)
    # gradA is filled in by the CUDA kernel; gradX and gradC are cheap
    # enough to compute directly with tensor ops:
    #     dL/dX_i = sum_k a_ik * dL/dE_k        (a batched matrix product)
    #     dL/dC_k = -sum_{b,i} a_ik * dL/dE_k
    A, X, C = self.saved_tensors
    with torch.cuda.device_of(A):
        # Allocate gradient buffers with the same dtype/device as the inputs.
        gradA = A.new().resize_as_(A)
        gradX = A.new().resize_as_(X)
        gradC = A.new().resize_as_(C)
    # Dispatch to the float/double CUDA kernel, which computes gradA in place.
    if isinstance(A, torch.cuda.FloatTensor):
        with torch.cuda.device_of(A):
            encoding_lib.Encoding_Float_aggregateE_backward(
                gradA, gradE, A, X, C)
    elif isinstance(A, torch.cuda.DoubleTensor):
        with torch.cuda.device_of(A):
            encoding_lib.Encoding_Double_aggregateE_backward(
                gradA, gradE, A, X, C)
    else:
        raise RuntimeError('Unimplemented data type!')
    # gradX: (B, N, K) x (B, K, D) -> (B, N, D)
    gradX.copy_(torch.bmm(A, gradE))
    # gradC: weight gradE by the assignments summed over N, then sum over the batch.
    gradC.copy_((-gradE * A.sum(1).unsqueeze(2)).sum(0))
    return gradA, gradX, gradC
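
As a sanity check on the two closed-form gradients above, here is a small pure-PyTorch sketch that reimplements the aggregate forward with ordinary tensor ops and compares autograd's gradients against the manual formulas. The shapes B, N, K, D and the `aggregate` helper are illustrative assumptions for this sketch, not part of the library; it also reconstructs the gradA formula that the CUDA kernel computes, under the assumption that the forward is E_k = sum_i a_ik (x_i - c_k).

import torch

# Hypothetical sizes: batch B, N descriptors per sample, K codewords of dim D.
B, N, K, D = 2, 5, 3, 4

A = torch.randn(B, N, K, dtype=torch.double, requires_grad=True)
X = torch.randn(B, N, D, dtype=torch.double, requires_grad=True)
C = torch.randn(K, D, dtype=torch.double, requires_grad=True)

def aggregate(A, X, C):
    # E[b, k] = sum_n A[b, n, k] * (X[b, n] - C[k])
    return torch.bmm(A.transpose(1, 2), X) - A.sum(1).unsqueeze(2) * C

E = aggregate(A, X, C)
gradE = torch.randn_like(E)     # an arbitrary upstream gradient
E.backward(gradE)

# The closed-form gradients used in the backward above.
gradX_manual = torch.bmm(A, gradE)
gradC_manual = (-gradE * A.sum(1).unsqueeze(2)).sum(0)
# What the CUDA kernel should produce for gradA (assumed forward as above):
# dL/dA[b, n, k] = sum_d gradE[b, k, d] * (X[b, n, d] - C[k, d])
gradA_manual = torch.bmm(X, gradE.transpose(1, 2)) - (gradE * C).sum(2).unsqueeze(1)

print(torch.allclose(X.grad, gradX_manual))  # True
print(torch.allclose(C.grad, gradC_manual))  # True
print(torch.allclose(A.grad, gradA_manual))  # True

Using double precision here keeps the comparison tight; the same check in float32 would need a looser tolerance in `torch.allclose`.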