def backward(self, gradE):
    # Gradients of the aggregate operation E[b,k] = sum_i A[b,i,k] * (X[b,i] - C[k]).
    # Shapes implied by the code below: A is B x N x K, X is B x N x D,
    # C is K x D, so gradE arrives as B x K x D.
    A, X, C = self.saved_variables
    with torch.cuda.device_of(A):
        # Allocate gradient buffers on the same device, with the same dtype as A.
        gradA = Variable(A.data.new().resize_as_(A.data))
        gradX = Variable(A.data.new().resize_as_(X.data))
        gradC = Variable(A.data.new().resize_as_(C.data))
    if isinstance(A.data, torch.cuda.FloatTensor):
        with torch.cuda.device_of(A.data):
            # gradA[b,i,k] = sum_d gradE[b,k,d] * (X[b,i,d] - C[k,d]),
            # computed by the CUDA kernel in the compiled extension.
            encoding_lib.Encoding_Float_aggregate_backward(
                gradA.data, gradE.data, A.data, X.data, C.data)
    elif isinstance(A.data, torch.cuda.DoubleTensor):
        with torch.cuda.device_of(A.data):
            encoding_lib.Encoding_Double_aggregate_backward(
                gradA.data, gradE.data, A.data, X.data, C.data)
    else:
        raise RuntimeError('Unimplemented data type!')
    # gradX[b,i] = sum_k A[b,i,k] * gradE[b,k]  (batched matrix product).
    gradX.data.copy_(torch.bmm(A, gradE).data)
    # gradC[k] = -sum_b (sum_i A[b,i,k]) * gradE[b,k], reduced over the batch.
    gradC.data.copy_((-gradE * A.sum(1).unsqueeze(2)).sum(0).data)
    return gradA, gradX, gradC
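
As a sanity check on the two hand-written gradients above (and on the gradA the Float/Double kernels must return), here is a minimal pure-PyTorch sketch that compares them against autograd, assuming the forward computes E[b,k] = sum_i A[b,i,k] * (X[b,i] - C[k]); aggregate_ref is a hypothetical reference helper written for this check, not part of encoding_lib:

import torch

def aggregate_ref(A, X, C):
    # Assumed forward: E[b,k] = sum_i A[b,i,k] * (X[b,i] - C[k])
    return torch.einsum('bik,bid->bkd', A, X) - A.sum(1).unsqueeze(2) * C

B, N, K, D = 2, 5, 4, 3
A = torch.randn(B, N, K, dtype=torch.double, requires_grad=True)
X = torch.randn(B, N, D, dtype=torch.double, requires_grad=True)
C = torch.randn(K, D, dtype=torch.double, requires_grad=True)

E = aggregate_ref(A, X, C)
gradE = torch.randn_like(E)
E.backward(gradE)

# gradX and gradC exactly as computed in the backward above:
assert torch.allclose(X.grad, torch.bmm(A.detach(), gradE))
assert torch.allclose(C.grad, (-gradE * A.detach().sum(1).unsqueeze(2)).sum(0))

# gradA as the CUDA kernel should produce it:
# gradA[b,i,k] = sum_d gradE[b,k,d] * (X[b,i,d] - C[k,d])
gradA_ref = (torch.einsum('bkd,bid->bik', gradE, X.detach())
             - torch.einsum('bkd,kd->bk', gradE, C.detach()).unsqueeze(1))
assert torch.allclose(A.grad, gradA_ref)

If all three assertions pass, the Python-side bmm/sum expressions and the kernel's gradA formula are consistent with the assumed aggregation forward.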