main.py 文件源码

python
阅读 63 收藏 0 点赞 0 评论 0

项目:VariationalAutoEncoder 作者: despoisj 项目源码 文件源码
def trainModel():
    """Create, train and save the variational auto-encoder.

    Steps, in order:

    1. Build the models with ``getModels()`` (only the first returned
       model is used here) and compile it with the ``'rmsprop'``
       optimizer and ``VAELoss`` as the loss.
    2. Load the train/test splits with ``loadDataset()``.
    3. Fit the model using the inputs as their own targets, validating
       on the test split and logging through a ``TensorBoard`` callback.
    4. Write the trained weights to ``modelsPath + "model.h5"``.

    ``nbEpoch``, ``batchSize`` and ``modelsPath`` are not defined in this
    function; they are read from the enclosing scope.

    Returns:
        None. The results are the log directory and the saved weights file.
    """
    # Create models
    print("Creating VAE...")
    vae, _, _ = getModels()
    vae.compile(optimizer='rmsprop', loss=VAELoss)

    print("Loading dataset...")
    X_train, X_test = loadDataset()

    # Train the VAE on dataset
    print("Training VAE...")
    runID = "VAE - ZZZ"
    tb = TensorBoard(log_dir='/tmp/logs/'+runID)
    # Input and target are the same array: the model is trained to
    # reconstruct what it is given.
    # NOTE(review): `nb_epoch` is the older Keras spelling of `epochs`;
    # confirm against the installed Keras version before renaming.
    vae.fit(
        X_train,
        X_train,
        shuffle=True,
        nb_epoch=nbEpoch,
        batch_size=batchSize,
        validation_data=(X_test, X_test),
        callbacks=[tb],
    )

    # Serialize weights to HDF5
    print("Saving weights...")
    vae.save_weights(modelsPath+"model.h5")

# Generates images and plots
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号