def get_mae(pred, actual): # only compute on non-zero terms pred = pred[actual.nonzero()].flatten() actual = actual[actual.nonzero()].flatten() return mean_absolute_error(pred, actual)