WeightNorm.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:OpenNMT-py 作者: OpenNMT 项目源码 文件源码
def forward(self, x, init=False):
    """Apply the weight-normalised 2-D convolution to ``x``.

    Two modes, selected by ``init``:

    * ``init is True`` -- data-dependent initialisation.  ``V`` is
      re-sampled from a zero-mean normal scaled by 0.05, ``x`` is convolved
      with the per-output-channel L2-normalised ``V``, and ``g`` / ``b`` are
      set from the per-channel mean and variance of that result so that the
      returned activations have (for this batch) zero mean and a standard
      deviation of ``init_scale`` per output channel.  The ``*_avg``
      buffers are then overwritten with the fresh parameter values.
    * otherwise -- regular forward pass with the effective kernel
      ``w = g * v / ||v||`` (norm taken per output channel) and bias ``b``.
      ``v``, ``g`` and ``b`` come from ``get_vars_maybe_avg``, which is
      defined outside this block (presumably it returns the Polyak-averaged
      copies when not training -- confirm against its definition).

    Args:
        x: input handed to ``F.conv2d``; dimension 1 of the convolution
            output is treated as the channel axis.
        init: pass ``True`` (the literal, checked with ``is``) to run the
            initialisation pass instead of a normal forward pass.

    Returns:
        The convolution output.  In init mode it is the normalised output
        wrapped in ``Variable`` and built from ``.data``, so it carries no
        autograd history.
    """
    if init is True:
        # V: (out_channels, in_channels // groups, *kernel_size)
        self.V.data.copy_(
            torch.randn(self.V.data.size()).type_as(self.V.data) * 0.05)

        # Per-output-channel L2 norm, reshaped to broadcast against V.
        kernel_ones = [1] * (len(self.kernel_size) + 1)
        v_len = self.V.data.view(self.out_channels, -1).norm(2, 1).view(
            self.out_channels, *kernel_ones)
        v_norm = self.V.data / v_len.expand_as(self.V.data)

        x_init = F.conv2d(x, Variable(v_norm), None, self.stride,
                          self.padding, self.dilation, self.groups).data

        # Bring channels to the front and flatten everything else so the
        # statistics are taken over all batch and spatial positions.
        t_x_init = x_init.transpose(0, 1).contiguous().view(
            self.out_channels, -1)

        # Flatten with view() rather than squeeze(1): depending on the
        # PyTorch version the reduction yields (out_channels, 1) or
        # (out_channels,), and squeeze(1) raises on the 1-D result.
        m_init = t_x_init.mean(1).view(self.out_channels)
        v_init = t_x_init.var(1).view(self.out_channels)

        # The small epsilon keeps the division finite for a constant channel.
        scale_init = self.init_scale / torch.sqrt(v_init + 1e-10)
        self.g.data.copy_(scale_init)
        self.b.data.copy_(-m_init * scale_init)

        # Shape (1, out_channels, 1, ...) so the stats broadcast over x_init.
        stat_shape = [1, self.out_channels] + [1] * (len(x_init.size()) - 2)
        scale_init_shape = scale_init.view(*stat_shape)
        m_init_shape = m_init.view(*stat_shape)
        x_init = scale_init_shape.expand_as(x_init) * (
            x_init - m_init_shape.expand_as(x_init))

        # Start the averaged copies from the freshly initialised values.
        self.V_avg.copy_(self.V.data)
        self.g_avg.copy_(self.g.data)
        self.b_avg.copy_(self.b.data)
        return Variable(x_init)

    v, g, b = get_vars_maybe_avg(
        self, ['V', 'g', 'b'], self.training,
        polyak_decay=self.polyak_decay)

    # Per-output-channel norm of v, flattened to 1-D regardless of whether
    # the reduction kept the reduced dimension.
    v_len = torch.norm(v.view(self.out_channels, -1), 2, 1).view(
        self.out_channels)
    scalar = g / v_len

    w = scalar.view(
        self.out_channels, *([1] * (len(v.size()) - 1))).expand_as(v) * v

    return F.conv2d(x, w, b, self.stride,
                    self.padding, self.dilation, self.groups)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号