tdlm_test.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:topically-driven-language-model 作者: jhlau 项目源码 文件源码
def run_epoch_doc(docs, labels, tags, tm, pad_id, cf):
    batches = int(math.ceil(float(len(docs))/cf.batch_size))
    accs = []
    for b in xrange(batches):
        d, y, m, t, num_docs = get_batch_doc(docs, labels, tags, b, cf.doc_len, cf.tag_len, cf.batch_size, pad_id)
        prob = sess.run(tm.sup_probs, {tm.doc:d, tm.label:y, tm.sup_mask: m, tm.tag: t})
        pred = np.argmax(prob, axis=1)
        accs.extend(pred[:num_docs] == y[:num_docs])

    print "\ntest classification accuracy = %.3f" % np.mean(accs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号