mnist.py source code


Project: a3c  Author: siemanko
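import threading

# BlockOnFullThreadPool is defined elsewhere in the project (not shown in this snippet).
def accuracy(session, graphs, data_iter, num_threads, train=False):
    """Return the fraction of correctly classified examples in data_iter.

    With train=True, each batch also runs the training step (dropout keep
    probability 0.5), so the result is accuracy measured during training.
    """
    num_total   = 0
    num_correct = 0
    counts_lock = threading.Lock()  # worker threads update the shared counters

    def process_batch(batch_x, batch_y):
        nonlocal num_correct
        nonlocal num_total
        # Borrow one graph (placeholders + ops) for the duration of this batch.
        with graphs.lease() as g:
            input_placeholder, output_placeholder, keep_prob_placeholder, train_step_f, num_correct_f, no_op = g
            batch_num_correct, _ = session.run(
                [num_correct_f, train_step_f if train else no_op],
                {
                    input_placeholder:     batch_x,
                    output_placeholder:    batch_y,
                    keep_prob_placeholder: 0.5 if train else 1.0,  # dropout only when training
                })
        # `+=` on a shared variable is not atomic across threads; guard it.
        with counts_lock:
            num_correct += batch_num_correct
            num_total   += len(batch_x)

    with BlockOnFullThreadPool(max_workers=num_threads, queue_size=num_threads // 2) as pool:
        for batch_x, batch_y in data_iter:
            pool.submit(process_batch, batch_x, batch_y)
        pool.shutdown(wait=True)  # wait for every batch before reading the counters

    return float(num_correct) / float(num_total)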
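Neither `BlockOnFullThreadPool` nor the object behind `graphs` is shown in this snippet. The following is a hypothetical, self-contained sketch of the same pattern, assuming `BlockOnFullThreadPool` behaves like a `ThreadPoolExecutor` whose `submit()` blocks once a fixed number of tasks are in flight, and that `lease()` borrows one item from a pool and returns it afterwards. `GraphPool`, `parallel_accuracy`, and the plain `predict` function (standing in for a TensorFlow session) are illustrative names, not part of the a3c project.

```python
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager


class BlockOnFullThreadPool(ThreadPoolExecutor):
    """Hypothetical sketch (not the project's code): submit() blocks once
    max_workers running + queue_size pending tasks are in flight."""

    def __init__(self, max_workers, queue_size):
        super().__init__(max_workers=max_workers)
        self._pending_slots = threading.BoundedSemaphore(max_workers + max(queue_size, 0))

    def submit(self, fn, *args, **kwargs):
        self._pending_slots.acquire()  # the producer waits here when the pool is full
        try:
            future = super().submit(fn, *args, **kwargs)
        except BaseException:
            self._pending_slots.release()
            raise
        # Free the slot when the task finishes, whether it succeeded or failed.
        future.add_done_callback(lambda _f: self._pending_slots.release())
        return future


class GraphPool:
    """Hypothetical stand-in for `graphs`: lease() lends out one item at a time."""

    def __init__(self, items):
        self._free = queue.Queue()
        for item in items:
            self._free.put(item)

    @contextmanager
    def lease(self):
        item = self._free.get()  # blocks until some item has been returned
        try:
            yield item
        finally:
            self._free.put(item)


def parallel_accuracy(batches, predict, num_threads):
    """Same structure as accuracy() above, with `predict` replacing session.run."""
    lock = threading.Lock()
    counts = {"correct": 0, "total": 0}
    models = GraphPool([predict] * num_threads)

    def process_batch(batch_x, batch_y):
        with models.lease() as model:
            correct = sum(int(model(x) == y) for x, y in zip(batch_x, batch_y))
        with lock:  # += on shared state is not atomic across threads
            counts["correct"] += correct
            counts["total"] += len(batch_x)

    with BlockOnFullThreadPool(max_workers=num_threads,
                               queue_size=num_threads // 2) as pool:
        futures = [pool.submit(process_batch, bx, by) for bx, by in batches]
    for f in futures:
        f.result()  # re-raise any exception raised in a worker thread
    return counts["correct"] / counts["total"]


batches = [([0, 1, 2, 3], [0, 1, 0, 0]), ([4, 5, 6, 7], [0, 1, 0, 1])]
print(parallel_accuracy(batches, lambda x: x % 2, num_threads=4))  # 7 of 8 correct → 0.875
```

Capping in-flight tasks with a semaphore keeps the producer loop from pulling the whole data iterator into memory when batches arrive faster than the workers can process them; collecting the futures and calling `result()` is one way to surface worker exceptions that would otherwise be silently dropped.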