train.py source code (Python)


Project: Structured-Self-Attentive-Sentence-Embedding    Author: ExplorerFreda
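import torch


def evaluate():
    """Evaluate the model on the validation set while training.

    Uses globals defined elsewhere in train.py: model, data_val, args,
    package, criterion. Returns (mean loss per batch, accuracy).
    """
    model.eval()  # switch to eval mode so dropout is disabled
    total_loss = 0.0
    total_correct = 0
    n_batches = 0
    with torch.no_grad():  # evaluation needs no autograd graph
        for i in range(0, len(data_val), args.batch_size):
            # `volatile` is the pre-0.4 PyTorch way to skip gradient tracking;
            # torch.no_grad() above does this on current versions
            data, targets = package(data_val[i:i + args.batch_size], volatile=True)
            if args.cuda:
                data = data.cuda()
                targets = targets.cuda()
            hidden = model.init_hidden(data.size(1))  # data.size(1) is the batch size
            output, attention = model(data, hidden)  # call the module, not .forward()
            output_flat = output.view(data.size(1), -1)  # (batch, n_classes)
            total_loss += criterion(output_flat, targets).item()
            prediction = output_flat.argmax(dim=1)
            total_correct += (prediction == targets).sum().item()
            n_batches += 1
    # Divide by the batches actually run: the last one may be partial
    return total_loss / n_batches, total_correct / len(data_val)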
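The loss denominator in `evaluate()` deserves a check. The loop visits `ceil(len(data_val) / args.batch_size)` batches. Dividing the summed loss by the floor `len(data_val) // args.batch_size` therefore overstates the mean whenever the last batch is partial. It fails outright (division by zero) when the validation set is smaller than one batch. The plain-Python sketch below checks that arithmetic. The helper names `batch_starts` and `summarize` are hypothetical, and the per-batch numbers are made up to stand in for `criterion(...)` and the correct-prediction counts.

```python
import math


def batch_starts(n_examples, batch_size):
    """Start offsets visited by range(0, n_examples, batch_size), as in evaluate()."""
    return list(range(0, n_examples, batch_size))


def summarize(batch_losses, batch_correct, n_examples):
    """Mean loss over the batches actually run, and accuracy over all examples."""
    mean_loss = sum(batch_losses) / len(batch_losses)
    accuracy = sum(batch_correct) / n_examples
    return mean_loss, accuracy


n_examples, batch_size = 10, 4
starts = batch_starts(n_examples, batch_size)
print(starts)  # [0, 4, 8] -> batches of size 4, 4 and 2
print(len(starts), math.ceil(n_examples / batch_size), n_examples // batch_size)  # 3 3 2

# Made-up per-batch results (hypothetical), standing in for the loss and correct counts
batch_losses = [1.0, 0.5, 0.75]
batch_correct = [3, 4, 1]
print(summarize(batch_losses, batch_correct, n_examples))  # (0.75, 0.8)
print(sum(batch_losses) / (n_examples // batch_size))      # 1.125: floor divisor inflates the mean
```

Counting the batches inside the loop keeps the average correct for any validation size. The accuracy is unaffected either way, because it is divided by the number of examples rather than the number of batches.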