def __init__(self, dis_path='result/dis_iter_500000.npz'):
    """Build a classifier whose conv layers start from a trained Discriminator.

    A ``Discriminator`` is instantiated and its parameters are loaded from
    ``dis_path``. The weights and biases of its ``c0`` and ``c1`` layers are
    then used as the initial values of this chain's ``c0`` and ``c1``.
    ``l2`` is a fresh ``Linear`` layer with He-normal initialization mapping
    ``7*7*128`` inputs to 10 outputs. ``bn1`` is a ``BatchNormalization``
    over 128 channels. ``c0`` is frozen with ``disable_update()`` so
    optimizers leave its pretrained parameters unchanged.

    Args:
        dis_path: Path of the ``.npz`` snapshot loaded into the
            ``Discriminator``. Defaults to the previously hard-coded path,
            so existing ``DiscriminatorClassifier()`` calls behave as before.
    """
    initializer = initializers.HeNormal()

    # Load the pretrained discriminator; only its c0/c1 parameters are used.
    # NOTE(review): assumes Discriminator.c0/c1 have the same parameter shapes
    # as the layers built below (1->64 and 64->128 channels, 4x4 kernels).
    # Confirm against the Discriminator definition.
    dis = Discriminator()
    chainer.serializers.load_npz(dis_path, dis)

    super(DiscriminatorClassifier, self).__init__()
    # Register child links via init_scope(). Passing links as keyword
    # arguments to Chain.__init__ is deprecated since Chainer v2.
    with self.init_scope():
        # Each 4x4 / stride-2 / pad-1 convolution halves the spatial size, so
        # the 7*7*128 input of l2 presumably corresponds to 28x28 inputs.
        # Confirm against forward().
        self.c0 = L.Convolution2D(
            1, 64, 4, stride=2, pad=1,
            initialW=dis.c0.W.data, initial_bias=dis.c0.b.data)
        self.c1 = L.Convolution2D(
            64, 128, 4, stride=2, pad=1,
            initialW=dis.c1.W.data, initial_bias=dis.c1.b.data)
        # 10 outputs; presumably the number of target classes. Confirm.
        self.l2 = L.Linear(7*7*128, 10, initialW=initializer)
        self.bn1 = L.BatchNormalization(128)

    # Keep the first conv layer fixed at its pretrained values. c1 is not
    # frozen, so it is fine-tuned along with l2 and bn1.
    self.c0.disable_update()
# def __init__(self):
# initializer = initializers.HeNormal()
# super(DiscriminatorClassifier, self).__init__(
# c0 = L.Convolution2D(1, 64, 4, stride=2, pad=1, initialW=initializer),
# c1 = L.Convolution2D(64, 128, 4, stride=2, pad=1, initialW=initializer),
# l2 = L.Linear(7*7*128, 10, initialW = initializer),
# bn1 = L.BatchNormalization(128),
# )
#
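# 评论列表 ("Comment list"; navigation residue from the source web page)
# 文章目录 ("Table of contents"; navigation residue from the source web page)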