train_model.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:ROCStory_skipthought_baseline 作者: soskek 项目源码 文件源码
def evaluate(dataset, model, args):
    sum_correct = 0.
    sum_loss_data = xp.zeros(())
    for i in six.moves.range(0, len(dataset), args.batchsize):
        x_batch_seq = make_batch([dataset[i + j:i + j + 1]
                                  for j in range(args.batchsize)], train=False)
        x_batch_seq, pos, neg = x_batch_seq[:4], x_batch_seq[4], x_batch_seq[5]
        loss, correct = model.solve(
            x_batch_seq, pos, neg, train=False, variablize=True)
        sum_loss_data += loss.data
        sum_correct += correct
    return cuda.to_cpu(sum_loss_data) / len(dataset), sum_correct
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号