mnist.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:a3c 作者: siemanko 项目源码 文件源码
def main(args):
    """Train the model for ``args.num_epochs`` epochs, then evaluate on the test set.

    Builds ``args.num_threads`` copies of the graph (every copy after the
    first is built with ``reuse=True``) and initializes all variables. For
    each epoch it runs one training pass and one validation pass through
    ``accuracy`` and prints the timing and accuracy figures. After training
    it prints the average training time per epoch and the test-set accuracy.

    Args:
        args: parsed command-line namespace. This function reads
            ``batch_size``, ``num_threads`` and ``num_epochs`` from it.
    """
    with tf.device("cpu"):
        data = Data(batch_size=args.batch_size, validation_size=6000)

        # Each op runs single-threaded (intra-op = 1), while up to
        # num_threads ops may run concurrently (inter-op). Presumably this
        # matches one worker thread per graph copy below -- TODO confirm.
        config = tf.ConfigProto(
            intra_op_parallelism_threads=1,
            inter_op_parallelism_threads=args.num_threads,
        )
        # Use the session as a context manager so it is closed (and its
        # resources released) even if training or evaluation raises.
        with tf.Session(config=config) as session:
            # One graph copy per worker thread. Only the first copy is built
            # with reuse=False; the rest presumably share its variables
            # (verify against build_graph).
            graphs = SharedResource(
                [build_graph(reuse=i > 0) for i in range(args.num_threads)]
            )

            # NOTE(review): tf.initialize_all_variables is deprecated in favour
            # of tf.global_variables_initializer on newer TF 1.x releases; it
            # is kept as-is to stay compatible with the TF version targeted.
            session.run(tf.initialize_all_variables())

            num_epochs = args.num_epochs
            train_total_time_sum = 0
            for epoch in range(num_epochs):
                # Only the training pass is timed; validation is excluded.
                train_start_time = time.time()
                train_accuracy = accuracy(session, graphs, data.iterate_train(),
                                          num_threads=args.num_threads, train=True)
                train_total_time = time.time() - train_start_time
                train_total_time_sum += train_total_time

                validate_accuracy = accuracy(session, graphs, data.iterate_validate(),
                                             num_threads=args.num_threads, train=False)

                # accuracy() is assumed to return a fraction in [0, 1]; it is
                # scaled by 100 for display as a percentage.
                print("Training epoch number %d:" % (epoch,))
                print("    Time to train           = %.3f s" % (train_total_time,))
                print("    Training set accuracy   = %.1f %%" % (100.0 * train_accuracy,))
                print("    Validation set accuracy = %.1f %%" % (100.0 * validate_accuracy,))
                print("")
            print("Training done.")

            test_accuracy = accuracy(session, graphs, data.iterate_test(),
                                     num_threads=args.num_threads, train=False)
            # Average over the number of epochs actually run, and avoid
            # dividing by zero when no training epochs were requested.
            if num_epochs > 0:
                print("    Average time per training epoch = %.3f s"
                      % (train_total_time_sum / num_epochs,))
            else:
                print("    Average time per training epoch = n/a (no epochs run)")
            print("    Test set accuracy               = %.1f %%" % (100.0 * test_accuracy,))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号