train_mixgan.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:MIX-plus-GAN 作者: yz-ignescent 项目源码 文件源码
def get_discriminator(self):
        """Build the discriminator network for 3x32x32 image input.

        The network is a stack of 3x3 convolutions with leaky-ReLU
        activations followed by a 192-unit NIN layer.  Two heads branch off
        that shared NIN layer:

        * a dense layer with ``self.args.z0dim`` linear units, used to
          recover the latent code z from the image x;
        * a global pooling layer followed by a dense layer with 10 linear
          units (the adversarial / pre-softmax output).

        Returns:
            A 3-tuple ``(layers, adv_layer, z_recon_layer)`` where ``layers``
            is the list of every layer in creation order (input first,
            adversarial head last), ``adv_layer`` is the 10-unit head and
            ``z_recon_layer`` is the latent-reconstruction head.
        """
        def conv3x3(incoming, num_filters, normalize=True, **conv_kwargs):
            # All convolutions share kernel size, initializer and activation;
            # only the filter count, padding and stride vary per layer.
            conv = dnn.Conv2DDNNLayer(
                incoming, num_filters, (3, 3),
                W=Normal(0.01), nonlinearity=nn.lrelu, **conv_kwargs)
            return nn.batch_norm(conv) if normalize else conv

        # Layers are created strictly in network order so that weight
        # initializers are drawn in the same sequence as before.
        layers = [LL.InputLayer(shape=(None, 3, 32, 32))]
        layers.append(LL.GaussianNoiseLayer(layers[-1], sigma=0.2))

        # First conv is the only one without batch normalization.
        layers.append(conv3x3(layers[-1], 96, normalize=False, pad=1))
        layers.append(conv3x3(layers[-1], 96, pad=1, stride=2))   # 32x32 -> 16x16
        layers.append(LL.DropoutLayer(layers[-1], p=0.5))

        layers.append(conv3x3(layers[-1], 192, pad=1))
        layers.append(conv3x3(layers[-1], 192, pad=1, stride=2))  # 16x16 -> 8x8
        layers.append(LL.DropoutLayer(layers[-1], p=0.5))

        layers.append(conv3x3(layers[-1], 192, pad=0))            # 8x8 -> 6x6

        # Feature layer shared by both output heads.
        shared = LL.NINLayer(layers[-1], num_units=192, W=Normal(0.01),
                             nonlinearity=nn.lrelu)
        layers.append(shared)

        # Head 1: recover z from x.
        # NOTE(review): unlike every other layer here this one does not pass
        # W=Normal(0.01) and so uses the library default initializer --
        # confirm that is intentional.
        z_recon = LL.DenseLayer(shared, num_units=self.args.z0dim,
                                nonlinearity=None)
        layers.append(z_recon)

        # Head 2: adversarial output on globally pooled shared features.
        layers.append(LL.GlobalPoolLayer(shared))
        adv = LL.DenseLayer(layers[-1], num_units=10, W=Normal(0.01),
                            nonlinearity=None)
        layers.append(adv)

        return layers, adv, z_recon
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号