def get_threshold(model_id):
trained_models = pd.read_csv(common.DEFAULT_TRAINED_MODELS_FILE, sep='\t')
model_config = trained_models[trained_models["model_id"] == model_id]
if model_config.empty:
raise ValueError("Can't find the model %s in %s" %
(model_id, common.DEFAULT_TRAINED_MODELS_FILE))
model_config = model_config.to_dict(orient="list")
model_settings=eval(model_config['dataset_settings'][0])
Y_test = np.load(common.DATASETS_DIR+'/item_factors_test_%s_%s_%s.npy' % (model_settings['fact'],model_settings['dim'],model_settings['dataset']))
Y_pred = np.load(common.FACTORS_DIR+'/factors_%s.npy' % model_id)
good_scores = Y_pred[Y_test==1]
th = good_scores.mean()
std = good_scores.std()
print 'Mean th',th
print 'Std',std
p, r, thresholds = precision_recall_curve(Y_test.flatten(), Y_pred.flatten())
f = np.nan_to_num((2 * (p*r) / (p+r)) * (p>r))
print f
max_f = np.argmax(f)
fth = thresholds[max_f]
print f[max_f],p[max_f],r[max_f]
print 'F th %.2f' % fth
plt.plot(r, p,
label='Precision-recall curve of class {0}')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Extension of Precision-Recall curve to multi-class')
plt.savefig("pr_curve.png")
评论列表
文章目录