encoding.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:PyTorch-Encoding 作者: zhanghang1989 项目源码 文件源码
def backward(self, gradE):
        A, X, C = self.saved_variables
        with torch.cuda.device_of(A):
            gradA = Variable(A.data.new().resize_as_(A.data))
            gradX = Variable(A.data.new().resize_as_(X.data))
            gradC = Variable(A.data.new().resize_as_(C.data))
        if isinstance(A.data, torch.cuda.FloatTensor):
            with torch.cuda.device_of(A.data):
                encoding_lib.Encoding_Float_aggregate_backward(gradA.data, 
                    gradE.data, A.data, X.data, C.data)
        elif isinstance(A.data, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(A.data):
                encoding_lib.Encoding_Double_aggregate_backward(gradA.data, 
                    gradE.data, A.data, X.data, C.data)
        else:
            raise RuntimeError('Unimplemented data type!')
        gradX.data.copy_(torch.bmm(A, gradE).data)
        gradC.data.copy_((-gradE*A.sum(1).unsqueeze(2)).sum(0).data)
        return gradA, gradX, gradC
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号