def predict_test_file(preds, sess, test_file, feature_cnt, _indices, _values, _values2, _cont_values,
                      _text_values, _shape, _cont_shape, _text_shape, _y, _ind, epoch, batch_size,
                      tag, path, output_prediction=True):
    """Run ``preds`` over every batch of ``test_file`` and return ``(auc, logloss)``.

    Batches come from ``load_data_cache(test_file)``. Each batch is a dict whose
    entries are fed into the matching graph placeholders (``_indices``,
    ``_values``, ``_y``, ...). The flattened predictions are compared against the
    flattened ``'labels'`` entry.

    Args:
        preds: Tensor evaluated by ``sess.run``; its result must support
            ``.reshape(-1)`` (presumably one score per example -- TODO confirm).
        sess: Session used to evaluate ``preds``.
        test_file: Passed straight to ``load_data_cache``; its format is defined there.
        feature_cnt: Unused here; kept for interface compatibility.
        _indices, _values, _values2, _cont_values, _text_values, _shape,
        _cont_shape, _text_shape, _y, _ind: Placeholders fed from the batch
            keys ``'indices'``, ``'values'``, ``'values2'``, ``'cont_values'``,
            ``'text_values'``, ``'shape'``, ``'cont_shape'``, ``'text_shape'``,
            ``'labels'`` and ``'feature_indices'`` respectively.
        epoch: Appended to the output file name.
        batch_size: Unused here; kept for interface compatibility.
        tag: Inserted into the output file name.
        path: Directory that receives the prediction file.
        output_prediction: When true, write one ``"<label>,<score>"`` line per
            example to ``<path>/<today>_deepFM_pred_<tag><epoch>.txt``.

    Returns:
        Tuple ``(auc, logloss)`` computed with ``sklearn.metrics`` over all batches.

    Raises:
        ValueError: propagated from ``sklearn.metrics`` (for example, when only one
            label class is present). Any predictions file is already closed by then.
    """
    wt = None
    if output_prediction:
        day = date.today()
        wt = open(path + '/' + str(day) + '_deepFM_pred_' + tag + str(epoch) + '.txt', 'w')

    gt_scores = []
    pred_scores = []
    # try/finally guarantees the prediction file is closed (and flushed) even if
    # sess.run, the loader, or a later metric computation raises.
    try:
        for test_input_in_sp in load_data_cache(test_file):
            predictions = sess.run(preds, feed_dict={
                _indices: test_input_in_sp['indices'],
                _values: test_input_in_sp['values'],
                _shape: test_input_in_sp['shape'],
                _cont_shape: test_input_in_sp['cont_shape'],
                _text_values: test_input_in_sp['text_values'],
                _text_shape: test_input_in_sp['text_shape'],
                _y: test_input_in_sp['labels'],
                _values2: test_input_in_sp['values2'],
                _cont_values: test_input_in_sp['cont_values'],
                _ind: test_input_in_sp['feature_indices'],
            }).reshape(-1).tolist()
            labels = test_input_in_sp['labels'].reshape(-1).tolist()

            if wt is not None:
                # zip pairs each label with its prediction (truncating to the shorter list).
                for gt, pred in zip(labels, predictions):
                    wt.write('{0:d},{1:f}\n'.format(int(gt), pred))
                    gt_scores.append(gt)
                    pred_scores.append(pred)
            else:
                gt_scores.extend(labels)
                pred_scores.extend(predictions)
    finally:
        if wt is not None:
            wt.close()

    y_true = np.asarray(gt_scores)
    y_pred = np.asarray(pred_scores)
    auc = metrics.roc_auc_score(y_true, y_pred)
    logloss = metrics.log_loss(y_true, y_pred)
    return auc, logloss
评论列表
文章目录