def predict(dataset_iter=test_iter, dataset=test, data_name="test"):
    """Run the trained relation-prediction model over *dataset* and write
    the top-k ranked relations per example to ``<results_path>/<data_name>.txt``.

    Each output line has the form
    ``<line_id> %%%% <relation> %%%% <label> %%%% <score>`` where label is 1
    iff the predicted relation equals the gold relation for that example.
    Also prints precision (top-1 accuracy) and the retrieval rate (fraction
    of examples whose gold relation appears anywhere in the top-k).

    NOTE(review): relies on module-level globals (`model`, `args`,
    `index2tag`, `results_path`) and default arguments bound to the globals
    `test_iter` / `test` at definition time — kept for interface
    compatibility with existing callers.
    Exits the process when ``args.dataset`` is not 'RelationPrediction'.
    """
    print("Dataset: {}".format(data_name))
    # Check the dataset mode once up front instead of re-testing it on
    # every batch (the original also had a second, redundant post-loop check).
    if args.dataset != 'RelationPrediction':
        print("Wrong Dataset")
        exit()
    model.eval()
    dataset_iter.init_epoch()
    n_correct = 0
    n_retrieved = 0
    fname = "{}.txt".format(data_name)
    # Sentence/line ids are aligned positionally with dataset examples.
    with open(os.path.join(args.data_dir, "lineids_{}.txt".format(data_name))) as fid:
        sent_id = [x.strip() for x in fid.readlines()]
    # 'with' guarantees the results file is flushed and closed even on error
    # (the original leaked both file handles).
    with open(os.path.join(results_path, fname), 'w') as results_file:
        for data_batch_idx, data_batch in enumerate(dataset_iter):
            scores = model(data_batch)
            # Top-1 accuracy against the gold relation labels.
            n_correct += (torch.max(scores, 1)[1].view(data_batch.relation.size()).data == data_batch.relation.data).sum()
            # Top-k relations per example; shape: (batch_size, k).
            top_k_scores, top_k_indices = torch.topk(scores, k=args.hits, dim=1, sorted=True)
            top_k_scores_array = top_k_scores.cpu().data.numpy()
            top_k_indices_array = top_k_indices.cpu().data.numpy()
            top_k_relations_array = index2tag[top_k_indices_array]
            for i, (relations_row, scores_row) in enumerate(zip(top_k_relations_array, top_k_scores_array)):
                # Global example index; assumes all batches except possibly
                # the last have exactly args.batch_size examples.
                index = (data_batch_idx * args.batch_size) + i
                example = data_batch.dataset.examples[index]
                for rel, score in zip(relations_row, scores_row):
                    if rel == example.relation:
                        label = 1
                        n_retrieved += 1
                    else:
                        label = 0
                    results_file.write(
                        "{} %%%% {} %%%% {} %%%% {}\n".format(sent_id[index], rel, label, score))
    P = 1. * n_correct / len(dataset)
    print("{} Precision: {:10.6f}%".format(data_name, 100. * P))
    print("no. retrieved: {} out of {}".format(n_retrieved, len(dataset)))
    retrieval_rate = 100. * n_retrieved / len(dataset)
    print("{} Retrieval Rate {:10.6f}".format(data_name, retrieval_rate))
# run the model on the dev set and write the output to a file
# (removed copy-paste residue from a scraped web page: "评论列表" [comment list], "文章目录" [article table of contents])