def accuracy(session, graphs, data_iter, num_threads, train=False, *,
             train_keep_prob=0.5):
    """Run every batch from *data_iter* through leased graphs and return accuracy.

    Batches are dispatched to a ``BlockOnFullThreadPool`` with ``num_threads``
    workers. Each worker leases one graph tuple from ``graphs`` for the
    duration of its ``session.run`` call.

    Args:
        session: Object whose ``run(fetches, feed_dict)`` executes graph ops
            (presumably a TensorFlow 1.x ``Session`` -- confirm).
        graphs: Pool exposing a ``lease()`` context manager that yields a
            6-tuple ``(input_placeholder, output_placeholder,
            keep_prob_placeholder, train_step_f, num_correct_f, no_op)``.
        data_iter: Iterable of ``(batch_x, batch_y)`` pairs. ``len(batch_x)``
            is counted as the number of examples in the batch.
        num_threads: Worker count for the thread pool.
        train: If true, also run ``train_step_f`` on each batch and feed
            ``train_keep_prob`` to the keep-prob placeholder. Otherwise run
            ``no_op`` and feed 1.0.
        train_keep_prob: Value fed to ``keep_prob_placeholder`` when ``train``
            is true. Defaults to 0.5, the previously hard-coded value.

    Returns:
        Total correct predictions divided by total examples, as a float.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no examples.
        Exception: The first exception raised while processing any batch is
            re-raised once the pool has drained, instead of being dropped.
    """
    import threading

    num_total = 0
    num_correct = 0
    errors = []
    # Workers update the shared counters concurrently. ``+=`` on a closed-over
    # int is a read-modify-write, not an atomic operation, so guard it.
    lock = threading.Lock()

    def process_batch(batch_x, batch_y):
        nonlocal num_correct, num_total
        try:
            with graphs.lease() as g:
                (input_placeholder, output_placeholder, keep_prob_placeholder,
                 train_step_f, num_correct_f, no_op) = g
                batch_num_correct, _ = session.run(
                    [num_correct_f, train_step_f if train else no_op],
                    {
                        input_placeholder: batch_x,
                        output_placeholder: batch_y,
                        keep_prob_placeholder: (
                            train_keep_prob if train else 1.0),
                    })
            with lock:
                num_correct += batch_num_correct
                num_total += len(batch_x)
        except Exception as exc:
            # Nobody inspects the futures returned by ``pool.submit``, so an
            # exception left here would be lost silently. Record it so the
            # caller sees it.
            with lock:
                errors.append(exc)

    # NOTE(review): for num_threads == 1 this passes queue_size=0; confirm
    # BlockOnFullThreadPool treats a zero-size queue as intended.
    with BlockOnFullThreadPool(max_workers=num_threads,
                               queue_size=num_threads // 2) as pool:
        for batch_x, batch_y in data_iter:
            if errors:
                # A batch already failed; stop doing (possibly training) work.
                break
            pool.submit(process_batch, batch_x, batch_y)
        # Drain all workers before reading the counters.
        pool.shutdown(wait=True)

    if errors:
        raise errors[0]
    if num_total == 0:
        raise ZeroDivisionError("accuracy(): data_iter yielded no examples")
    return float(num_correct) / float(num_total)
# Comment list
# Table of contents