deepFM.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:DeepFM 作者: dwt0317 项目源码 文件源码
def predict_test_file(preds, sess, test_file, feature_cnt, _indices, _values, _values2, _cont_values, _text_values, _shape,
                      _cont_shape, _text_shape, _y, _ind, epoch, batch_size, tag, path, output_prediction=True):
    """Evaluate `preds` on every batch of `test_file` and return (auc, logloss).

    Each batch yielded by ``load_data_cache(test_file)`` is fed to
    ``sess.run(preds, ...)``; the flattened predictions and the flattened
    ``'labels'`` entry of the batch are accumulated over all batches and
    scored with ``metrics.roc_auc_score`` and ``metrics.log_loss``.

    Args:
        preds: Fetch passed to ``sess.run``; its result must support
            ``.reshape(-1).tolist()``.
        sess: Object exposing ``run(fetches, feed_dict=...)``
            (presumably a TensorFlow session -- confirm against caller).
        test_file: Argument forwarded to ``load_data_cache``.
        feature_cnt: Unused here; kept for interface compatibility.
        _indices, _values, _values2, _cont_values, _text_values, _shape,
        _cont_shape, _text_shape, _y, _ind: Feed-dict keys, bound to the
            batch entries ``'indices'``, ``'values'``, ``'values2'``,
            ``'cont_values'``, ``'text_values'``, ``'shape'``,
            ``'cont_shape'``, ``'text_shape'``, ``'labels'`` and
            ``'feature_indices'`` respectively.
        epoch: Appended (via ``str``) to the output file name.
        batch_size: Unused here; kept for interface compatibility.
        tag: String inserted into the output file name.
        path: Directory in which the prediction file is created.
        output_prediction: When true, write one ``label,prediction`` line
            per sample to
            ``<path>/<today>_deepFM_pred_<tag><epoch>.txt``.

    Returns:
        Tuple ``(auc, logloss)`` as computed by the ``metrics`` module.
    """
    gt_scores = []
    pred_scores = []

    wt = None
    if output_prediction:
        day = date.today()
        wt = open(path + '/' + str(day) + '_deepFM_pred_' + tag + str(epoch) + '.txt', 'w')

    # try/finally guarantees the prediction file is closed even if loading a
    # batch or running the session raises part-way through.
    try:
        for test_input_in_sp in load_data_cache(test_file):
            feed = {
                _indices: test_input_in_sp['indices'],
                _values: test_input_in_sp['values'],
                _values2: test_input_in_sp['values2'],
                _cont_values: test_input_in_sp['cont_values'],
                _text_values: test_input_in_sp['text_values'],
                _shape: test_input_in_sp['shape'],
                _cont_shape: test_input_in_sp['cont_shape'],
                _text_shape: test_input_in_sp['text_shape'],
                _y: test_input_in_sp['labels'],
                _ind: test_input_in_sp['feature_indices'],
            }
            predictions = sess.run(preds, feed_dict=feed).reshape(-1).tolist()
            labels = test_input_in_sp['labels'].reshape(-1).tolist()

            if wt is not None:
                wt.writelines(
                    '{0:d},{1:f}\n'.format(int(gt), preded)
                    for gt, preded in zip(labels, predictions)
                )

            # Accumulate identically whether or not predictions are written,
            # so the returned metrics do not depend on `output_prediction`.
            gt_scores.extend(labels)
            pred_scores.extend(predictions)
    finally:
        if wt is not None:
            wt.close()

    auc = metrics.roc_auc_score(np.asarray(gt_scores), np.asarray(pred_scores))
    logloss = metrics.log_loss(np.asarray(gt_scores), np.asarray(pred_scores))
    return auc, logloss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号