aggregate.py source code

python

Project: PyTorch-Encoding  Author: zhanghang1989
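```python
import torch

# backward() of the project's old-style (non-static) autograd Function for the
# aggregate op. `encoding_lib` is the project's compiled CUDA extension; its
# import is not shown in this snippet.
def backward(self, gradE):
    # Shapes implied by the bmm/sum calls below:
    # A: B x N x K, X: B x N x D, C: K x D, gradE: B x K x D
    A, X, C = self.saved_tensors
    # Allocate gradient buffers on the same GPU as A
    with torch.cuda.device_of(A):
        gradA = A.new().resize_as_(A)
        gradX = A.new().resize_as_(X)
        gradC = A.new().resize_as_(C)
    # gradA is computed by a custom CUDA kernel, dispatched on precision
    if isinstance(A, torch.cuda.FloatTensor):
        with torch.cuda.device_of(A):
            encoding_lib.Encoding_Float_aggregateE_backward(
                gradA, gradE, A, X, C)
    elif isinstance(A, torch.cuda.DoubleTensor):
        with torch.cuda.device_of(A):
            encoding_lib.Encoding_Double_aggregateE_backward(
                gradA, gradE, A, X, C)
    else:
        raise RuntimeError('Unimplemented data type!')
    # gradX[b] = A[b] @ gradE[b]  ->  B x N x D
    gradX.copy_(torch.bmm(A, gradE))
    # gradC = -sum over batch of gradE weighted by each codeword's total
    # assignment weight A.sum(1)  ->  K x D
    gradC.copy_((-gradE * A.sum(1).unsqueeze(2)).sum(0))
    return gradA, gradX, gradC
```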
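The gradient formulas in this backward pass can be checked without a GPU. The sketch below is a minimal pure-Python reference under stated assumptions: it assumes the forward pass computes E[b][k][d] = sum_i A[b][i][k] * (X[b][i][d] - C[k][d]), which is what the gradX and gradC lines imply, and since the CUDA kernel behind gradA is not shown, its formula here is likewise an assumption. The helpers `aggregate`, `aggregate_backward`, `loss`, `numeric_grad` and `max_grad_error` are hypothetical names, not part of PyTorch-Encoding.

```python
import itertools
import random

# ASSUMED forward: E[b][k][d] = sum_i A[b][i][k] * (X[b][i][d] - C[k][d])
# A: B x N x K, X: B x N x D, C: K x D, E: B x K x D (nested lists)

def aggregate(A, X, C):
    B, N, K, D = len(A), len(A[0]), len(C), len(C[0])
    return [[[sum(A[b][i][k] * (X[b][i][d] - C[k][d]) for i in range(N))
              for d in range(D)] for k in range(K)] for b in range(B)]

def aggregate_backward(gradE, A, X, C):
    B, N, K, D = len(A), len(A[0]), len(C), len(C[0])
    # gradA: ASSUMED to match what the CUDA kernel computes
    gradA = [[[sum(gradE[b][k][d] * (X[b][i][d] - C[k][d]) for d in range(D))
               for k in range(K)] for i in range(N)] for b in range(B)]
    # gradX = bmm(A, gradE)
    gradX = [[[sum(A[b][i][k] * gradE[b][k][d] for k in range(K))
               for d in range(D)] for i in range(N)] for b in range(B)]
    # gradC = (-gradE * A.sum(1).unsqueeze(2)).sum(0)
    gradC = [[-sum(gradE[b][k][d] * sum(A[b][i][k] for i in range(N))
                   for b in range(B))
              for d in range(D)] for k in range(K)]
    return gradA, gradX, gradC

def loss(gradE, A, X, C):
    # Scalar sum(gradE * E): its gradient with respect to E is exactly gradE
    E = aggregate(A, X, C)
    return sum(gradE[b][k][d] * E[b][k][d]
               for b in range(len(E)) for k in range(len(E[0]))
               for d in range(len(E[0][0])))

def numeric_grad(f, T, idx, eps=1e-6):
    # Central difference of f() w.r.t. T[idx]; mutates T and restores it
    *outer, last = idx
    row = T
    for j in outer:
        row = row[j]
    orig = row[last]
    row[last] = orig + eps
    up = f()
    row[last] = orig - eps
    down = f()
    row[last] = orig
    return (up - down) / (2 * eps)

def rand_tensor(*shape):
    if len(shape) == 1:
        return [random.uniform(-1, 1) for _ in range(shape[0])]
    return [rand_tensor(*shape[1:]) for _ in range(shape[0])]

def max_grad_error(B=2, N=3, K=2, D=4, seed=0):
    # Largest gap between analytic and numerical gradients over A, X and C
    random.seed(seed)
    A, X, C = rand_tensor(B, N, K), rand_tensor(B, N, D), rand_tensor(K, D)
    gradE = rand_tensor(B, K, D)
    gradA, gradX, gradC = aggregate_backward(gradE, A, X, C)
    f = lambda: loss(gradE, A, X, C)
    err = 0.0
    for T, G, shape in ((A, gradA, (B, N, K)), (X, gradX, (B, N, D)),
                        (C, gradC, (K, D))):
        for idx in itertools.product(*(range(s) for s in shape)):
            g = G
            for j in idx:
                g = g[j]
            err = max(err, abs(numeric_grad(f, T, idx) - g))
    return err
```

Because `sum(gradE * E)` is linear in each of A, X and C taken separately, central differences match the analytic gradients up to floating-point rounding, so `max_grad_error()` should come out well below 1e-6.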