def score(net, samples=4096):
    """Compute the ROC AUC score of a trained net on a sample of test data.

    Indices are obtained from ``make_valid_indices`` for the net's test
    source, the net predicts probabilities for them, and the flattened
    predictions are compared against the flattened ground-truth events.

    Args:
        net: Trained network exposing ``batch_iterator_test.source`` and
            ``predict_proba``.
        samples: Number of samples requested from ``make_valid_indices``
            (presumably drawn at random -- confirm against that helper).

    Returns:
        The value returned by ``roc_auc_score``, or ``0`` if computing the
        score raises an exception.
    """
    source = net.batch_iterator_test.source
    test_indices = make_valid_indices(source, samples)
    predicted = net.predict_proba(test_indices)
    # A last dimension other than N_EVENTS means the output is still in an
    # encoded form and has to be passed through decode() first.
    if predicted.shape[-1] != N_EVENTS:
        predicted = decode(predicted)
    actual = source.events[test_indices]
    try:
        return roc_auc_score(actual.reshape(-1), predicted.reshape(-1))
    except Exception:
        # Best-effort: an uncomputable score is reported as 0 (presumably the
        # usual cause is a sample containing a single class -- confirm).
        # Catching Exception rather than using a bare ``except`` lets
        # KeyboardInterrupt and SystemExit propagate instead of being
        # silently turned into a score.
        return 0
# (Web-page residue, not part of the code: "Comment list" / "Table of contents".)