def run_epoch_doc(docs, labels, tags, tm, pad_id, cf):
    """Evaluate document-classification accuracy over a whole dataset.

    Splits the data into ``ceil(len(docs) / cf.batch_size)`` batches via
    ``get_batch_doc``, runs ``tm.sup_probs`` on each batch, takes the arg-max
    over axis 1 as the predicted label, and compares it with the gold labels.
    The overall accuracy is printed and returned.

    Args:
        docs, labels, tags: Dataset collections forwarded unchanged to
            ``get_batch_doc`` (defined elsewhere in this file) together with
            the batch index.
        tm: Model object; its ``sup_probs`` output is evaluated and its
            ``doc``, ``label``, ``sup_mask`` and ``tag`` members are used as
            feed keys.
        pad_id: Padding id forwarded to ``get_batch_doc``.
        cf: Config object; ``batch_size``, ``doc_len`` and ``tag_len`` are read.

    Returns:
        The mean accuracy as a float, or ``None`` when no documents were
        scored (e.g. ``docs`` is empty).

    NOTE(review): relies on a module-level ``sess`` (presumably a TensorFlow
    session created elsewhere) rather than taking it as a parameter.
    """
    num_batches = int(math.ceil(float(len(docs)) / cf.batch_size))
    matches = []
    # `range` instead of `xrange` so this also runs under Python 3.
    for b in range(num_batches):
        d, y, m, t, num_docs = get_batch_doc(
            docs, labels, tags, b, cf.doc_len, cf.tag_len, cf.batch_size, pad_id)
        prob = sess.run(tm.sup_probs,
                        {tm.doc: d, tm.label: y, tm.sup_mask: m, tm.tag: t})
        pred = np.argmax(prob, axis=1)
        # Only the first `num_docs` rows are scored; presumably the rest of
        # the batch is padding added by get_batch_doc -- confirm.
        matches.extend(pred[:num_docs] == y[:num_docs])

    if not matches:
        # np.mean([]) would emit a RuntimeWarning and yield nan.
        print("\ntest classification accuracy = n/a (no documents)")
        return None

    acc = float(np.mean(matches))
    # Single-argument print(...) prints identically under Python 2 and 3.
    print("\ntest classification accuracy = %.3f" % acc)
    return acc
tdlm_test.py 文件源码
python
阅读 30
收藏 0
点赞 0
评论 0
评论列表
文章目录