baseline_crf.py source code

python

Project: imSitu  Author: my89
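import json
import os

import torch


def predict_human_readable(dataset_loader, simple_dataset, model, outdir, top_k):
    # Writes the top_k highest-scoring situation predictions for each image
    # in dataset_loader to "<image name>.predictions" (JSON) under outdir.
    # NOTE: `encoder` is not a parameter; it is assumed to be a module-level
    # global (the situation encoder) defined elsewhere in baseline_crf.py.
    model.eval()
    print("predicting...")
    mx = len(dataset_loader)
    with torch.no_grad():  # replaces the deprecated Variable(..., volatile=True)
        for i, (inputs, index) in enumerate(dataset_loader):
            print("{}/{} batches".format(i + 1, mx))
            input_var = inputs.cuda()
            # predictions has shape (batch, p, d); scores is indexed per (image, prediction)
            scores, predictions = model.forward_max(input_var)
            human = encoder.to_situation(predictions)
            b, p, _ = predictions.size()
            for _b in range(b):
                items = []
                offset = _b * p  # human is flat: image _b owns entries [offset, offset + p)
                for _p in range(p):
                    items.append(human[offset + _p])
                    # float() turns the 0-dim tensor into a JSON-serializable number
                    items[-1]["score"] = float(scores[_b][_p])
                # keep the top_k predictions, highest score first
                items = sorted(items, key=lambda x: x["score"], reverse=True)[:top_k]
                # strip the file extension and append ".predictions"
                name = simple_dataset.images[int(index[_b][0])].split(".")[:-1]
                name.append("predictions")
                outfile = os.path.join(outdir, ".".join(name))
                with open(outfile, "w") as f:  # close the file after writing
                    json.dump(items, f)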
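The per-image ranking and output-naming logic inside predict_human_readable does not depend on PyTorch or the GPU. Below is a minimal pure-Python sketch of just that part, assuming `human` is the flat, image-major list of decoded situation dicts (as the `offset = _b * p` indexing implies) and `scores` is a nested b x p list of floats. The helper name `rank_batch` and the `"verb"` keys in the example data are hypothetical placeholders, not part of imSitu or its encoder's real output format.

```python
import os
from typing import Dict, List


def rank_batch(
    human: List[dict],
    scores: List[List[float]],
    image_names: List[str],
    outdir: str,
    top_k: int,
) -> Dict[str, List[dict]]:
    """Hypothetical helper mirroring the inner loop of predict_human_readable.

    human: flat list of decoded situations for the batch, image-major, so
        entries [b * p, (b + 1) * p) belong to image b.
    scores: scores[b][j] is the score of the j-th prediction for image b.
    image_names: file name of each image in the batch.
    Returns a mapping from output path to that image's top_k situations.
    """
    b = len(scores)
    p = len(scores[0]) if b else 0
    assert len(human) == b * p, "expected one decoded situation per score"
    out = {}
    for _b in range(b):
        offset = _b * p
        items = []
        for _p in range(p):
            # copy so the caller's dicts are not mutated
            item = dict(human[offset + _p])
            item["score"] = scores[_b][_p]
            items.append(item)
        # highest score first, then keep the top_k
        items.sort(key=lambda x: x["score"], reverse=True)
        # "dir/img.jpg" -> "<outdir>/dir/img.predictions"
        stem = image_names[_b].split(".")[:-1]
        stem.append("predictions")
        out[os.path.join(outdir, ".".join(stem))] = items[:top_k]
    return out


if __name__ == "__main__":
    human = [{"verb": "jumping"}, {"verb": "running"},
             {"verb": "eating"}, {"verb": "drinking"}]
    scores = [[0.2, 0.7], [0.9, 0.1]]
    result = rank_batch(human, scores, ["a.jpg", "b.jpg"], "out", top_k=1)
    for path, items in result.items():
        print(path, items)
```

Unlike the original loop, the sketch copies each dict before adding `score`, so the decoder output is left unchanged; that is a choice made for this sketch, not something imSitu does.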