decoder.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Semi-supervised_Neural_Network 作者: jibancanyang 项目源码 文件源码
def bn_hat_z_layers(self, hat_z_layers, z_pre_layers, noise_std=1 - 1e-10, eps=1e-10):
        """Normalize each ``hat_z`` with batch statistics of the matching ``z_pre``.

        For every pair ``(hat_z, z_pre)`` this returns
        ``(hat_z - mean(z_pre)) / sqrt(var(z_pre + noise) + eps)``.
        Mean and variance are taken over dimension 0. The inputs are assumed to
        be 2-D, presumably ``(batch, features)``; TODO confirm against callers.
        Presumably ``z_pre`` holds the clean-encoder pre-activations of a ladder
        network; verify against the caller.

        Gradients reach ``z_pre`` through the mean only. The variance is
        computed from ``z_pre.data`` and therefore acts as a constant.

        Args:
            hat_z_layers: per-layer decoder reconstructions to normalize.
            z_pre_layers: per-layer tensors supplying the statistics; must have
                the same length as ``hat_z_layers``.
            noise_std: standard deviation of the Gaussian noise added to
                ``z_pre`` before its variance is computed. The default keeps
                the historical behaviour. Pass ``0`` to use the exact batch
                variance. NOTE(review): noise with std ~1 inflates every
                variance by ~1 in expectation; confirm this is intended.
            eps: constant added to the variance before the square root.

        Returns:
            A list with one normalized tensor per layer, moved to the GPU when
            ``self.use_cuda`` is set.
        """
        assert len(hat_z_layers) == len(z_pre_layers)
        hat_z_layers_normalized = []
        for hat_z, z_pre in zip(hat_z_layers, z_pre_layers):
            # The mean keeps its autograd history, so gradients flow into z_pre here.
            mean = torch.mean(z_pre, 0)

            # Compute the variance on the raw (gradient-free) data, on z_pre's own
            # device, instead of round-tripping through NumPy on the CPU.
            stats = z_pre.data
            if noise_std:
                # .new(...) allocates on the same device/dtype as z_pre.
                # Noise now comes from torch's RNG (seed it via torch.manual_seed).
                stats = stats + stats.new(*stats.size()).normal_(0.0, noise_std)
            # Population variance (divide by N), matching np.var's default ddof=0.
            var = ((stats - stats.mean(0)) ** 2).mean(0)
            std = Variable(torch.sqrt(var + eps))

            # Broadcasting over the batch axis replaces the ones-column matmuls.
            hat_z_normalized = (hat_z - mean) / std
            if self.use_cuda:
                hat_z_normalized = hat_z_normalized.cuda()
            hat_z_layers_normalized.append(hat_z_normalized)
        return hat_z_layers_normalized
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号