def predict_human_readable(dataset_loader, simple_dataset, model, outdir, top_k):
    """Write the highest-scoring decoded predictions of each image to JSON.

    Switches ``model`` to eval mode, runs ``model.forward_max`` on every batch
    of ``dataset_loader``, decodes the predictions with the module-level
    ``encoder.to_situation`` and, for each image in the batch, dumps the
    ``top_k`` best-scoring decoded items (each given an extra ``"score"``
    key) to its own JSON file. Progress is printed to stdout.

    Args:
        dataset_loader: Sized iterable yielding ``(input, index)`` batches.
            ``input`` must support ``.cuda()``; ``index[row][0]`` is used to
            look up the image name in ``simple_dataset.images``.
        simple_dataset: Object whose ``images`` maps that index to a file
            name string.
        model: Provides ``eval()`` and ``forward_max(input_var)`` returning
            ``(scores, predictions)``.
        outdir: Prefix prepended verbatim to the output file name. No path
            separator is inserted, so a directory should end with one.
        top_k: Maximum number of items written per image.
    """
    model.eval()
    # Single-argument call form prints the same text under Python 2 and 3.
    print("predicting...")
    total_batches = len(dataset_loader)
    for batch_no, (batch_input, index) in enumerate(dataset_loader, 1):
        print("{}/{} batches".format(batch_no, total_batches))
        # NOTE(review): `volatile` belongs to the legacy Variable API and is
        # ignored (with a warning) from PyTorch 0.4 on; torch.no_grad() is the
        # replacement. Kept as-is to match the API the rest of the file uses.
        input_var = torch.autograd.Variable(batch_input.cuda(), volatile=True)
        scores, predictions = model.forward_max(input_var)
        human = encoder.to_situation(predictions)
        # predictions.size() must have exactly three dimensions; only the
        # first two are used here.
        # Presumably (batch, candidates per image, ...) - TODO confirm.
        batch_size, per_image, _ = predictions.size()
        for row in range(batch_size):
            # `human` is indexed as a flat sequence: the candidates of image
            # `row` occupy the slots starting at row * per_image.
            offset = row * per_image
            items = []
            for col in range(per_image):
                item = human[offset + col]
                item["score"] = scores.data[row][col]
                items.append(item)
            # Best score first; reverse=True keeps ties in their original
            # order, exactly like sorting on the negated score.
            items = sorted(
                items, key=lambda entry: entry["score"], reverse=True
            )[:top_k]

            # Swap the last extension for "predictions" ("a.b.jpg" ->
            # "a.b.predictions"). A name without any "." keeps its full stem
            # so extension-less images no longer all collapse onto the single
            # file "predictions".
            image_name = simple_dataset.images[index[row][0]]
            stem, dot, _ = image_name.rpartition(".")
            if not dot:
                stem = image_name
            outfile = outdir + stem + ".predictions"

            # Close the file deterministically instead of leaking one handle
            # per image.
            with open(outfile, "w") as handle:
                json.dump(items, handle)
评论列表
文章目录