def main(args):
    # Build the input pipeline and the per-thread graphs on the CPU.
    with tf.device("cpu"):
        data = Data(batch_size=args.batch_size, validation_size=6000)
        # Limit each op to a single intra-op thread; parallelism comes from allowing
        # up to args.num_threads ops to run concurrently in the inter-op pool.
        session = tf.Session(config=tf.ConfigProto(
            intra_op_parallelism_threads=1,
            inter_op_parallelism_threads=args.num_threads))
        # One graph copy per worker thread; copies after the first reuse the variables.
        graphs = SharedResource([build_graph(reuse=i > 0) for i in range(args.num_threads)])
        # tf.initialize_all_variables() is the pre-1.0 name for tf.global_variables_initializer().
        session.run(tf.initialize_all_variables())

        train_total_time_sum = 0
        for epoch in range(args.num_epochs):
            train_start_time = time.time()
            train_accuracy = accuracy(session, graphs, data.iterate_train(),
                                      num_threads=args.num_threads, train=True)
            train_total_time = time.time() - train_start_time
            train_total_time_sum += train_total_time

            validate_accuracy = accuracy(session, graphs, data.iterate_validate(),
                                         num_threads=args.num_threads, train=False)

            print("Training epoch number %d:" % (epoch,))
            print(" Time to train = %.3f s" % (train_total_time,))
            print(" Training set accuracy = %.1f %%" % (100.0 * train_accuracy,))
            print(" Validation set accuracy = %.1f %%" % (100.0 * validate_accuracy,))
            print("")

        print("Training done.")
        test_accuracy = accuracy(session, graphs, data.iterate_test(),
                                 num_threads=args.num_threads, train=False)
        print(" Average time per training epoch = %.3f s" % (train_total_time_sum / args.num_epochs,))
        print(" Test set accuracy = %.1f %%" % (100.0 * test_accuracy,))
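Data, build_graph, accuracy, and SharedResource are defined elsewhere in the article. As a rough sketch only of the kind of helper SharedResource might be (the acquire/release names and the queue-based design are assumptions, not the article's implementation), a thread-safe pool that lends each worker thread its own graph copy could be written as:

import queue

class SharedResource(object):
    """Thread-safe pool: each worker borrows one item, uses it, then returns it."""

    def __init__(self, items):
        self._queue = queue.Queue()
        for item in items:
            self._queue.put(item)

    def acquire(self):
        # Blocks until an item is free, so at most len(items) workers run at once.
        return self._queue.get()

    def release(self, item):
        self._queue.put(item)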