baseline_crf.py source code

python

Project: imSitu  Author: my89
import torch

# imSituTensorEvaluation is provided by the imSitu project itself.
def eval_model(dataset_loader, encoding, model):
    """Run the model over dataset_loader and return the (top1, top5) evaluators."""
    model.eval()
    print("evaluating model...")
    top1 = imSituTensorEvaluation(1, 3, encoding)
    top5 = imSituTensorEvaluation(5, 3, encoding)

    mx = len(dataset_loader)
    # torch.no_grad() replaces the old `volatile=True` Variable flag.
    with torch.no_grad():
        for i, (index, input, target) in enumerate(dataset_loader):
            print("{}/{} batches".format(i + 1, mx), end="\r")
            (scores, predictions) = model.forward_max(input.cuda())
            # Sort scores along dim 1 in descending order.
            (s_sorted, idx) = torch.sort(scores, 1, True)
            top1.add_point(target, predictions, idx)
            top5.add_point(target, predictions, idx)

    print("\ndone.")
    return (top1, top5)
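
The function above sorts each row of scores in descending order and hands the resulting indices to a top-1 and a top-5 evaluator. As a rough illustration of what a top-k check over sorted indices can look like, here is a minimal pure-Python sketch. It is not the imSituTensorEvaluation implementation; the names `topk_hit` and `topk_accuracy` and the sample data are hypothetical.

```python
def topk_hit(scores, target, k):
    """Return True if `target` is among the k highest-scoring indices."""
    # Rank indices by score, highest first (analogous to a descending sort
    # along dim 1 of a score matrix).
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return target in ranked[:k]


def topk_accuracy(batch_scores, batch_targets, k):
    """Fraction of rows whose target index falls within the top k."""
    hits = sum(topk_hit(s, t, k) for s, t in zip(batch_scores, batch_targets))
    return hits / float(len(batch_targets))


# Hypothetical data: three examples scored over three classes.
batch_scores = [
    [0.1, 0.7, 0.2],  # ranking: 1, 2, 0
    [0.5, 0.3, 0.2],  # ranking: 0, 1, 2
    [0.2, 0.3, 0.5],  # ranking: 2, 1, 0
]
batch_targets = [1, 1, 0]

top1 = topk_accuracy(batch_scores, batch_targets, 1)
top2 = topk_accuracy(batch_scores, batch_targets, 2)
```

With this data only the first example is a top-1 hit, while the first two are top-2 hits, so a larger k can only keep or raise the accuracy.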