movielens_vae_test.py source code

Project: VAE-MF-TensorFlow  Author: arongdari
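```python
import itertools
import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.ConfigProto, tf.Session)

# Defined elsewhere (module level or project imports, not shown in this
# excerpt): read_dataset, VAEMF, num_user, num_item, n_steps, and the
# hyperparameter lists hedims, hddims, ldims, lrates, bsizes, regs, vaes.


def train_test_validation():
    M = read_dataset()

    # Shuffle the indices of the observed (non-zero) ratings with a fixed
    # seed so the split is reproducible across runs.
    num_rating = np.count_nonzero(M)
    idx = np.arange(num_rating)
    np.random.seed(1)
    np.random.shuffle(idx)

    # 85% train, 5% validation, 10% test.
    train_idx = idx[:int(0.85 * num_rating)]
    valid_idx = idx[int(0.85 * num_rating):int(0.90 * num_rating)]
    test_idx = idx[int(0.90 * num_rating):]

    # Grid search over every combination of hyperparameter values.
    for (hidden_encoder_dim, hidden_decoder_dim, latent_dim, learning_rate,
         batch_size, reg_param, vae) in itertools.product(
            hedims, hddims, ldims, lrates, bsizes, regs, vaes):
        result_path = "{0}_{1}_{2}_{3}_{4}_{5}_{6}".format(
            hidden_encoder_dim, hidden_decoder_dim, latent_dim,
            learning_rate, batch_size, reg_param, vae)
        # Only train configurations that have no saved checkpoint yet.
        if not os.path.exists(result_path + "/model.ckpt.index"):
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # allocate GPU memory on demand
            with tf.Session(config=config) as sess:
                model = VAEMF(sess, num_user, num_item,
                              hidden_encoder_dim=hidden_encoder_dim,
                              hidden_decoder_dim=hidden_decoder_dim,
                              latent_dim=latent_dim, learning_rate=learning_rate,
                              batch_size=batch_size, reg_param=reg_param, vae=vae)
                print("Train size={0}, Validation size={1}, Test size={2}".format(
                    train_idx.size, valid_idx.size, test_idx.size))
                print(result_path)
                best_rmse = model.train_test_validation(
                    M, train_idx=train_idx, test_idx=test_idx, valid_idx=valid_idx,
                    n_steps=n_steps, result_path=result_path)

                print("Best RMSE = {0}".format(best_rmse))

                # Append one row: the seven hyperparameters, then the best RMSE.
                with open('result.csv', 'a') as f:
                    f.write("{0},{1},{2},{3},{4},{5},{6},{7}\n".format(
                        hidden_encoder_dim, hidden_decoder_dim, latent_dim,
                        learning_rate, batch_size, reg_param, vae, best_rmse))

        # Reset the default graph before building the next model.
        tf.reset_default_graph()
```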
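The two mechanisms in train_test_validation(), the seeded 85% / 5% / 10% split of observed-rating indices and the resumable grid search that appends one CSV row per configuration, can be exercised without TensorFlow or the MovieLens data. The following is a minimal sketch under stated assumptions: split_rating_indices and run_grid are hypothetical helper names, evaluate is a stand-in callback for building and training a VAEMF model, and an empty "done" marker file replaces the model.ckpt.index checkpoint that the original code checks for.

```python
import csv
import itertools
import os

import numpy as np


def split_rating_indices(M, seed=1, train_end=0.85, valid_end=0.90):
    """Shuffle the indices 0..nnz(M)-1 with a fixed seed and cut them into
    train / validation / test arrays (85% / 5% / 10% by default)."""
    num_rating = np.count_nonzero(M)
    idx = np.arange(num_rating)
    # A local RandomState seeded like np.random.seed(seed) yields the same
    # permutation without touching NumPy's global random state.
    np.random.RandomState(seed).shuffle(idx)
    a, b = int(train_end * num_rating), int(valid_end * num_rating)
    return idx[:a], idx[a:b], idx[b:]


def run_grid(grid, evaluate, out_dir, results_csv):
    """Evaluate every combination of the value lists in `grid`, skip any
    combination whose result directory already holds a "done" marker, and
    append one CSV row (hyperparameters, then score) per new evaluation.

    `evaluate(params)` is a hypothetical callback standing in for building
    and training a VAEMF model; it should return a single score (e.g. RMSE).
    Returns the number of combinations evaluated in this call."""
    evaluated = 0
    for params in itertools.product(*grid):
        # Same naming scheme as result_path: values joined by underscores.
        result_path = os.path.join(out_dir, "_".join(str(p) for p in params))
        marker = os.path.join(result_path, "done")  # stand-in for model.ckpt.index
        if os.path.exists(marker):
            continue
        score = evaluate(params)
        os.makedirs(result_path, exist_ok=True)
        open(marker, "w").close()  # mark as finished only after evaluate returns
        with open(results_csv, "a", newline="") as f:
            csv.writer(f).writerow(list(params) + [score])
        evaluated += 1
    return evaluated
```

Calling run_grid([hedims, hddims, ldims, lrates, bsizes, regs, vaes], evaluate, ".", "result.csv") would walk the same seven-way grid as the original loop. Because the marker is written only after evaluate returns, a configuration interrupted mid-training is retried on the next run rather than silently skipped.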