def train(epochs=1, batchsize=128):
    batchCount = X_train.shape[0] // batchsize
    print 'Epochs:', epochs
    print 'Batch size:', batchsize
    print 'Batches per epoch:', batchCount
    # range builds a full list, xrange returns a lazy generator (Python 2)
    for e in xrange(1, epochs + 1):
        print '-' * 15, 'Epoch %d' % e, '-' * 15
        for _ in tqdm(xrange(batchCount)):
            # Get a random batch of input noise and real images
            noise = np.random.normal(0, 1, size=[batchsize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchsize)]

            # Generate fake MNIST images
            generatedImages = generator.predict(noise)
            # Stack real and fake images along axis=0 (the default), equivalent to np.vstack
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for the discriminator: real images get -1, generated images keep +1
            # (Wasserstein-style labels rather than the usual 0/1)
            yDis = np.ones(2 * batchsize)
            yDis[:batchsize] = -1

            # Train the discriminator
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train the generator: push fake images toward the "real" label -1
            noise = np.random.normal(0, 1, size=[batchsize, randomDim])
            yGen = np.ones(batchsize) * -1
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)

            # Optional WGAN weight clipping (needs a clip_weight helper):
            # d_weight = discriminator.get_weights()
            # d_weight = clip_weight(d_weight, -0.01, 0.01)
            # discriminator.set_weights(d_weight)

        # Store the loss of the most recent batch from this epoch
        Dloss.append(dloss)
        Gloss.append(gloss)
        if e == 1 or e % 5 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot the losses collected over all epochs
    plot_loss(e)
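
The commented-out clipping step calls a clip_weight helper that is not defined in this section. A minimal sketch of what it presumably does, with the name and signature taken from the call above: clip every weight array of the discriminator elementwise into a fixed range, the standard WGAN trick for enforcing the Lipschitz constraint.

def clip_weight(weights, low, high):
    # Clip each weight array elementwise into [low, high]
    return [np.clip(w, low, high) for w in weights]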
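The -1/+1 labels only make sense if the models were compiled with a Wasserstein-style loss rather than binary cross-entropy. The compilation happens outside this snippet, but a rough sketch of the assumed setup (wasserstein_loss and the optimizer settings here are my assumptions, not the author's exact code) would look like:

from keras import backend as K
from keras.optimizers import RMSprop

def wasserstein_loss(y_true, y_pred):
    # With real labels -1 and fake labels +1, minimizing this mean
    # approximates the Wasserstein critic objective
    return K.mean(y_true * y_pred)

discriminator.compile(loss=wasserstein_loss, optimizer=RMSprop(lr=0.00005))
discriminator.trainable = False
gan.compile(loss=wasserstein_loss, optimizer=RMSprop(lr=0.00005))

With the models compiled, the loop above would be started with something like train(epochs=200, batchsize=128).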