def threshold_weighted_unique_confusion_matrix(y_true, y_pred,
                                               weights, ids, th=0.5):
    """
    Computes a weighted event-wise confusion matrix with a threshold on
    predictions.

    Takes numpy arrays.

    Arguments:
        y_true - one-hot encoded labels, one row per waveform; every row
                 must contain exactly one nonzero entry
        y_pred - per-class prediction scores, one row per waveform; column 1
                 is treated as the signal-class probability
        weights - weights for each waveform
        ids - ids to correlate waveforms with events
        th - probability threshold; a waveform is predicted as signal when
             y_pred[:, 1] is strictly greater than th (default: 0.5)
    Returns:
        confusion_matrix - a numpy array containing the confusion matrix,
                           as returned by weighted_unique_confusion_matrix
    Raises:
        ValueError - if some row of y_true does not contain exactly one
                     nonzero entry
    """
    # Flattening with nonzero()[1] only yields one label per row when every
    # row is truly one-hot. An empty or multi-hot row would silently drop or
    # duplicate labels and misalign them with y_pred/weights/ids, so reject
    # such input explicitly instead of computing a wrong matrix.
    nonzero_per_row = (y_true != 0).sum(axis=1)
    if (nonzero_per_row != 1).any():
        raise ValueError("y_true must be one-hot: each row needs exactly "
                         "one nonzero entry")

    # nonzero() returns indices in row-major order, so with exactly one
    # nonzero per row the column indices are the class labels in row order.
    labels = y_true.nonzero()[1]

    # Threshold the signal-class column to get boolean signal predictions.
    predicted_signal = y_pred[:, 1] > th

    return weighted_unique_confusion_matrix(labels, predicted_signal,
                                            weights, ids)
评论列表
文章目录