wgan.py source code

python

Project: simple-wgan-with-minist  Author: ray0809
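import numpy as np
from tqdm import tqdm

# X_train, randomDim, generator, discriminator, gan, Dloss, Gloss,
# plotGeneratedImages, saveModels, plot_loss and clip_weight are assumed to be
# module-level names defined elsewhere in wgan.py (not shown here).

def train(epochs=1, batchsize=128):
    # Integer division: number of full batches per epoch
    batchCount = X_train.shape[0] // batchsize
    print('Epochs', epochs)
    print('Batch size', batchsize)
    print('Batches per epoch', batchCount)
    # In Python 3, range() is lazy, like Python 2's xrange()
    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batchCount)):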
            #Get a random set of input noise and images
            noise = np.random.normal(0,1,size=[batchsize,randomDim])
            imageBatch = X_train[np.random.randint(0,X_train.shape[0],size=batchsize)]

            #generate fake MNIST images
            generatedImages = generator.predict(noise)

            # Stack real images on top of generated ones along axis 0
            # (np.concatenate's default axis, equivalent to np.vstack here)
            X = np.concatenate([imageBatch, generatedImages])

            # WGAN critic labels: -1 for real images, +1 for generated ones
            # (assuming a Wasserstein loss of the form mean(y_true * y_pred))
            yDis = np.ones(2 * batchsize)
            yDis[:batchsize] = -1

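            # Train the critic (discriminator). In Keras, `trainable` is read at
            # compile() time, so toggling it on an already-compiled model has no
            # effect until that model is compiled again.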
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X,yDis)

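            # Train the generator through the combined model: its samples get
            # label -1, the same label the critic uses for real images
            noise = np.random.normal(0, 1, size=[batchsize, randomDim])
            yGen = -np.ones(batchsize)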
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise,yGen)

            # WGAN weight clipping of the critic to [-0.01, 0.01];
            # disabled in the original code:
            # d_weight = discriminator.get_weights()
            # d_weight = clip_weight(d_weight, -0.01, 0.01)
            # discriminator.set_weights(d_weight)
        #Store loss of most recent batch from this epoch
        Dloss.append(dloss)
        Gloss.append(gloss)

        if e == 1 or e % 5 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    plot_loss(e)
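The ±1 labels in train() only make sense together with the loss the models are compiled with, which this excerpt does not show. The following is a minimal NumPy sketch, assuming the common WGAN formulation mean(y_true * y_pred); `wasserstein_loss` and `clip_weight` here are illustrative stand-ins, not the definitions from wgan.py.

```python
import numpy as np


def wasserstein_loss(y_true, y_pred):
    # Assumed WGAN loss: mean of label * critic score.
    # With labels -1 (real) and +1 (fake), minimizing it raises real scores
    # and lowers fake scores.
    return np.mean(y_true * y_pred)


def clip_weight(weights, low, high):
    # Hypothetical helper: clip every critic weight array to [low, high],
    # mirroring the call clip_weight(d_weight, -0.01, 0.01) in train().
    return [np.clip(w, low, high) for w in weights]


batchsize = 4
# Toy critic scores ordered like np.concatenate([imageBatch, generatedImages]):
# first half real images, second half generated images.
scores = np.array([2.0, 1.5, 1.0, 0.5, -1.0, 0.0, 0.5, -0.5])

yDis = np.ones(2 * batchsize)
yDis[:batchsize] = -1
d_loss = wasserstein_loss(yDis, scores)  # -0.75

# Generator step: generated samples labeled -1
yGen = -np.ones(batchsize)
g_loss = wasserstein_loss(yGen, scores[batchsize:])  # 0.25

print(d_loss, g_loss)
```

With a balanced real/fake batch the critic loss equals half of mean(D(fake)) - mean(D(real)), so minimizing it widens the gap between real and fake scores, while the generator loss is -mean(D(G(z))). Weight clipping is what bounds the critic's Lipschitz constant in the original WGAN, which is worth keeping in mind since the clipping call in train() is disabled.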