gram_matrix.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:fast-neural-style-pyfunt 作者: dnlcrl 项目源码 文件源码
def update_grad_input(self, x, grad_output):
        self.grad_input = np.zeros_like(x)
        if x.dims == 3:
            C, H, W = x.shape
            x_flat = x.view(C, H * W)
            self.buffer = np.dot(grad_output, x_flat)
            self.buffer += np.dot(grad_output.T, x_flat)
            self.grad_input = self.buffer.view(C, H, W)
        if x.dims == 4:
            N, C, H, W = x.shape
            x_flat = x.view(N, C, H * W)
            self.buffer = np.tensordot(grad_output, x_flat, 2)
            self.buffer += np.tensordot(grad_output.transpose(2, 3), x_flat, 2)
            self.grad_input = self.buffer.view(N, C, H, W)
        if self.normalize:
            self.buffer /= C * H * W
        return self.grad_input
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号