MNIST_classify.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:GAN 作者: lyakaap 项目源码 文件源码
def __init__(self, dis_path='result/dis_iter_500000.npz'):
        """Build a 10-way classifier whose conv layers come from a GAN discriminator.

        The convolution layers ``c0`` and ``c1`` are initialized with the
        weights and biases of a ``Discriminator`` loaded from ``dis_path``.
        The linear head ``l2`` is freshly initialized with He-normal weights,
        and ``bn1`` is a new 128-channel batch-normalization layer.

        Args:
            dis_path: Path of the ``.npz`` discriminator snapshot to load.
                Defaults to the previously hard-coded
                ``'result/dis_iter_500000.npz'``, so existing callers behave
                as before. Pass ``None`` to skip loading: ``c0``/``c1`` are
                then He-normal initialized and nothing is frozen (a
                from-scratch baseline).
        """
        # Register links via init_scope() instead of the Chain(**links)
        # constructor form, which is deprecated since Chainer v2.
        super(DiscriminatorClassifier, self).__init__()
        initializer = initializers.HeNormal()

        if dis_path is None:
            c0_w = c1_w = initializer
            c0_b = c1_b = None  # None -> Convolution2D's default bias init
        else:
            dis = Discriminator()
            chainer.serializers.load_npz(dis_path, dis)
            # NOTE(review): assumes Discriminator.c0 / c1 are 1->64 and
            # 64->128 4x4 convs so these arrays match the shapes below.
            c0_w, c0_b = dis.c0.W.data, dis.c0.b.data
            c1_w, c1_b = dis.c1.W.data, dis.c1.b.data

        with self.init_scope():
            self.c0 = L.Convolution2D(1, 64, 4, stride=2, pad=1,
                                      initialW=c0_w, initial_bias=c0_b)
            self.c1 = L.Convolution2D(64, 128, 4, stride=2, pad=1,
                                      initialW=c1_w, initial_bias=c1_b)
            # 7*7*128: with k=4, s=2, p=1 each conv halves the spatial size,
            # so a 28x28 input gives 7x7 maps (assumes 28x28 MNIST input).
            self.l2 = L.Linear(7 * 7 * 128, 10, initialW=initializer)
            self.bn1 = L.BatchNormalization(128)

        if dis_path is not None:
            # Only the first pretrained conv layer is frozen; c1 is also
            # initialized from the discriminator but remains trainable.
            self.c0.disable_update()

#    def __init__(self):
#        initializer = initializers.HeNormal()
#        super(DiscriminatorClassifier, self).__init__(
#            c0 = L.Convolution2D(1, 64, 4, stride=2, pad=1, initialW=initializer),
#            c1 = L.Convolution2D(64, 128, 4, stride=2, pad=1, initialW=initializer),
#            l2 = L.Linear(7*7*128, 10, initialW = initializer),
#            bn1 = L.BatchNormalization(128),
#        )
#
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号