function.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:instance_normalization_chainer 作者: crcrpar 项目源码 文件源码
def backward(self, inputs, grad_outputs):
        """Compute gradients w.r.t. the input, scale and shift.

        Args:
            inputs: Sequence whose first two entries are ``x`` and ``gamma``
                (a third entry, ``beta``, may follow but is not needed here).
            grad_outputs: Sequence whose first entry is ``gy``, the gradient
                of the loss w.r.t. the output.

        Returns:
            Tuple ``(gx, ggamma, gbeta)``.

        NOTE(review): this assumes ``forward`` stored ``self.x_hat`` with the
        same shape as ``x`` and ``self.std`` with one entry per
        (sample, channel), both computed over axes (2, 3) -- confirm against
        ``forward``. Like the original, it is hard-wired to 4-d inputs.
        """
        x, gamma = inputs[:2]
        gy = grad_outputs[0]
        head_ndim = gamma.ndim + 1
        # Broadcasts the per-channel gamma against the 4-d input.
        expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
        spatial_axis = (2, 3)
        gamma_beta_axis = (0, 2, 3)
        mean_var_expander = (Ellipsis, None, None)
        # Each (sample, channel) plane is normalized on its own, so the
        # number of elements behind every mean/std is H * W, not N * H * W.
        m = gamma.dtype.type(x.shape[2] * x.shape[3])
        xp = cuda.get_array_module(x)

        gy_x_hat = gy * self.x_hat
        # gamma and beta are shared by all samples: reduce over the batch too.
        gbeta = gy.sum(axis=gamma_beta_axis)
        ggamma = gy_x_hat.sum(axis=gamma_beta_axis)
        # The input gradient only couples elements of the same plane, so its
        # correction terms are reduced over the spatial axes only. Using the
        # batch-wide ggamma / gbeta here (the batch-normalization formula)
        # is only correct when the batch size is 1.
        sum_gy = gy.sum(axis=spatial_axis)
        sum_gy_x_hat = gy_x_hat.sum(axis=spatial_axis)

        if xp is numpy:
            gx = (gamma / self.std)[mean_var_expander] * (
                gy - (self.x_hat * sum_gy_x_hat[mean_var_expander] +
                      sum_gy[mean_var_expander]) / m)
        else:
            inv_m = numpy.float32(1) / m
            gx = cuda.elementwise(
                'T gy, T x_hat, T gamma, T std, T sum_gy_x_hat, T sum_gy, '
                'T inv_m',
                'T gx',
                'gx = (gamma / std) * '
                '(gy - (x_hat * sum_gy_x_hat + sum_gy) * inv_m)',
                'bn_bwd')(gy, self.x_hat, gamma[expander],
                          self.std[mean_var_expander],
                          sum_gy_x_hat[mean_var_expander],
                          sum_gy[mean_var_expander], inv_m)
        return gx, ggamma, gbeta
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号