def trainModel(runID="VAE - ZZZ", logRoot='/tmp/logs/'):
    """Build, train and checkpoint the VAE.

    Creates the VAE with ``getModels()`` and compiles it with the RMSprop
    optimizer and ``VAELoss``. It then fits the model on the training split
    returned by ``loadDataset()``, validating on the test split and logging
    to TensorBoard. Finally it writes the trained weights to
    ``modelsPath + "model.h5"``.

    Relies on names defined elsewhere in the module: ``getModels``,
    ``VAELoss``, ``loadDataset``, ``TensorBoard``, ``nbEpoch``,
    ``batchSize`` and ``modelsPath``.

    Args:
        runID: Name of this training run. It is appended to ``logRoot`` to
            form the TensorBoard log directory.
        logRoot: Directory prefix for TensorBoard logs. It is concatenated
            as-is with ``runID``, so it should end with a path separator.

    Returns:
        The trained VAE model.
    """
    # Create models; only the first of the three returned by getModels() is used.
    print("Creating VAE...")
    vae, _, _ = getModels()
    vae.compile(optimizer='rmsprop', loss=VAELoss)

    print("Loading dataset...")
    X_train, X_test = loadDataset()

    # Train the VAE: the input doubles as the target because the model
    # is trained to reconstruct its input.
    print("Training VAE...")
    tb = TensorBoard(log_dir=logRoot + runID)
    # NOTE(review): `nb_epoch` is the Keras 1 keyword. Keras 2 renamed it to
    # `epochs` and keeps `nb_epoch` only as a deprecated alias; tf.keras
    # rejects it. Switch to `epochs=` if the project is on Keras 2+.
    vae.fit(X_train, X_train,
            shuffle=True,
            nb_epoch=nbEpoch,
            batch_size=batchSize,
            validation_data=(X_test, X_test),
            callbacks=[tb])

    # Serialize weights to HDF5 (the file name carries the .h5 extension).
    print("Saving weights...")
    vae.save_weights(modelsPath + "model.h5")
    # NOTE(review): the original ended with a "Generates images and plots"
    # comment, but no code followed it in this view.
    return vae
评论列表
文章目录