def get_accuracy(data_loader, classifier_fn, batch_size):
    """
    Compute classification accuracy over a labelled data set.

    Intended for the supervised training set or the testing set.

    Args:
        data_loader: iterable yielding ``(xs, ys)`` batches of tensors.
            ``ys`` is compared element-wise against ``classifier_fn(xs)``.
            Presumably both are one-hot ``(batch, num_classes)`` tensors
            (TODO confirm against callers). Any pair of same-shaped
            tensors works, e.g. 1-D integer class labels.
        classifier_fn: callable mapping a batch ``xs`` to predictions with
            the same shape as the matching ``ys``.
        batch_size: unused. The accuracy denominator is the number of
            examples actually seen, so a short final batch is counted
            correctly. The parameter is kept so existing callers keep
            working.

    Returns:
        Fraction (float in [0, 1]) of examples whose prediction matches the
        label in every element. Returns 0.0 if the loader yields no examples.
    """
    accurate_preds = 0
    total = 0
    # Score each batch as it arrives instead of buffering every prediction
    # (and any autograd graph attached to it) until the end.
    for xs, ys in data_loader:
        # No Variable wrapping: since PyTorch 0.4 Variable(t) is a deprecated
        # no-op that just returns a tensor.
        preds = classifier_fn(xs)
        n = preds.size(0)
        if n == 0:
            continue
        # A row counts as correct only if all of its elements match. This
        # replaces a hard-coded "10 matching elements" test that only worked
        # for exactly 10 classes.
        row_correct = (preds == ys).reshape(n, -1).all(dim=1)
        # .item() converts the 0-dim sum to a Python number; indexing a 0-dim
        # tensor with [0] raises IndexError in PyTorch >= 0.4.
        accurate_preds += int(row_correct.sum().item())
        total += n

    if total == 0:
        return 0.0
    # calculate the accuracy between 0 and 1
    return accurate_preds / total
# Comment list
# Table of contents