test.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:supervised-embedding-model 作者: sld 项目源码 文件源码
def evaluate_one_row(candidates_tensor, true_context, sess, model, test_score, true_response):
    for batch in batch_iter(candidates_tensor, 512):
        candidate_responses = batch[:, 0, :]
        context_batch = np.repeat(true_context, candidate_responses.shape[0], axis=0)

        scores = sess.run(
            model.f_pos,
            feed_dict={model.context_batch: context_batch,
                       model.response_batch: candidate_responses,
                       model.neg_response_batch: candidate_responses}
        )
        for ind, score in enumerate(scores):
            if score == float('Inf') or score == -float('Inf') or score == float('NaN'):
                print(score, ind, scores[ind])
                raise ValueError
            if score >= test_score and not np.array_equal(candidate_responses[ind], true_response):
                return False
    return True
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号