transform_filter.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:GrouPy 作者: tscohen 项目源码 文件源码
def backward_gpu(self, inputs, grad_output):

        w, = inputs
        grad_rotated_w, = grad_output
        xp = cuda.get_array_module(w)

        # Gradient must be initialized with zeros,
        # because the kernel accumulates the gradient instead of overwriting it
        grad_w = xp.zeros_like(w)

        grad_index_group_func_kernel(
            grad_output=grad_rotated_w,
            T=self.T,
            U=self.U,
            V=self.V,
            grad_input=grad_w
        )

        return grad_w,
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号