def evaluate():
    """Evaluate the model on the validation set while training.

    Reads module-level names defined elsewhere in this file: ``model``,
    ``data_val``, ``args`` (``args.batch_size``, ``args.cuda``), ``package``
    and ``criterion``.

    Switches ``model`` to eval mode (so dropout is disabled) and does not
    switch it back; the caller is responsible for calling ``model.train()``
    again if training continues.

    Returns:
        A ``(mean_loss, accuracy)`` tuple of Python numbers:
        ``mean_loss`` is the summed ``criterion`` value divided by the number
        of batches actually evaluated, and ``accuracy`` is the fraction of
        examples whose arg-max prediction equals the target.

    Raises:
        ValueError: if ``data_val`` is empty.
    """

    def _to_scalar(value):
        # PyTorch >= 0.4 yields 0-dim tensors here, which cannot be indexed
        # with [0]; older releases yield 1-element tensors. Support both.
        if hasattr(value, "item"):
            return value.item()
        return value[0]

    num_examples = len(data_val)
    if num_examples == 0:
        raise ValueError("data_val is empty; nothing to evaluate")

    model.eval()  # turn on the eval() switch to disable dropout
    total_loss = 0
    total_correct = 0
    num_batches = 0
    for start in range(0, num_examples, args.batch_size):
        end = min(num_examples, start + args.batch_size)
        # NOTE(review): volatile=True presumably requests pre-0.4 PyTorch
        # inference mode (no autograd graph) inside package(); verify there.
        data, targets = package(data_val[start:end], volatile=True)
        if args.cuda:
            data = data.cuda()
            targets = targets.cuda()
        # data.size(1) is used as the batch size below; presumably data is
        # laid out as (seq_len, batch) -- confirm against package().
        hidden = model.init_hidden(data.size(1))
        # Call the module itself rather than .forward() so hooks are run.
        output, _ = model(data, hidden)
        output_flat = output.view(data.size(1), -1)
        # Accumulate detached tensors (.data) and convert once at the end.
        total_loss += criterion(output_flat, targets).data
        prediction = torch.max(output_flat, 1)[1]
        total_correct += torch.sum((prediction == targets).float()).data
        num_batches += 1

    # Divide by the batches actually evaluated (ceil(len / batch_size)), so a
    # trailing partial batch is counted and a dataset smaller than one batch
    # does not divide by zero.
    mean_loss = _to_scalar(total_loss) / num_batches
    accuracy = _to_scalar(total_correct) / num_examples
    return mean_loss, accuracy
train.py 文件源码
python
阅读 36
收藏 0
点赞 0
评论 0
评论列表
文章目录