train.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:instacart-basket-prediction 作者: colinmorris 项目源码 文件源码
def train(traindat, tag, hps):
  """Train an XGBoost model on `traindat`, then persist the model and results.

  The model is trained with early stopping against a watchlist of the training
  and validation sets. Afterwards the model is saved, a per-user mean F-score
  is computed on the validation set, and a results dict (fscore, xgboost
  evals_result, training duration in seconds) is pickled.

  Args:
    traindat: training dataset; must provide `as_dmatrix()`.
    tag: identifier used to resolve the saved model path and to name the
      results pickle (`<tag>.pickle`).
    hps: hyperparameters; passed to `Dataset` and
      `hypers.xgb_params_from_hps`, and read for `rounds` and
      `early_stopping_rounds`.

  Returns:
    None. Outputs are written to disk.
  """
  valdat = Dataset('validation', hps, mode=Mode.eval)
  # TODO: try set_base_margin (https://github.com/dmlc/xgboost/blob/master/demo/guide-python/boost_from_prediction.py)
  with time_me('Made training dmatrix', mode='stderr'):
    dtrain = traindat.as_dmatrix()
  # Built before quick_fscore is defined so the closure's dependency is obvious.
  dval = valdat.as_dmatrix()

  def quick_fscore(preds, _notused_dtrain):
    """Return ('fscore', mean per-user F1) of `preds` on the validation set.

    The (preds, dmatrix) signature matches what xgboost expects of a custom
    `feval`; the second argument is ignored and the validation labels are
    always taken from `dval`.
    """
    # Module-level call counter, consulted only by the disabled throttle below.
    global counter
    counter += 1
    # Throttle is switched off: `0 and ...` always short-circuits to False.
    if 0 and counter % 5 != 0:
      return 'fscore', 0.0
    with time_me('calculated validation fscore', mode='print'):
      user_counts = defaultdict(lambda: dict(tpos=0, fpos=0, fneg=0))
      uids = valdat.uids
      labels = dval.get_label()
      for i, prob in enumerate(preds):
        # Look the user up unconditionally so that users whose rows are all
        # true negatives still get an entry. Previously they were never
        # registered, which made the `denom == 0` case below unreachable and
        # dropped those users from the mean.
        counts = user_counts[uids[i]]
        pred = prob >= THRESH
        label = labels[i]
        if pred and label:
          counts['tpos'] += 1
        elif pred:
          counts['fpos'] += 1
        elif label:
          counts['fneg'] += 1
      if not user_counts:
        # No predictions at all: avoid dividing by zero.
        return 'fscore', 0.0
      fscore_sum = 0.0
      for res in user_counts.values():
        numerator = 2 * res['tpos']
        denom = numerator + res['fpos'] + res['fneg']
        if denom == 0:
          # Nothing to predict and nothing predicted counts as a perfect score.
          fscore = 1.0
        else:
          # float() keeps this a true division even under Python 2 semantics.
          fscore = float(numerator) / denom
        fscore_sum += fscore
      return 'fscore', fscore_sum / len(user_counts)

  # If you pass in more than one value to evals, early stopping uses the
  # last one. Because why not.
  watchlist = [(dtrain, 'train'), (dval, 'validation'),]
  #watchlist = [(dval, 'validation'),]

  xgb_params = hypers.xgb_params_from_hps(hps)
  evals_result = {}
  t0 = time.time()
  # quick_fscore can be plugged in as the eval metric by adding
  # `feval=quick_fscore, maximize=True` to this call.
  model = xgb.train(xgb_params, dtrain, hps.rounds, evals=watchlist,
      early_stopping_rounds=hps.early_stopping_rounds, evals_result=evals_result)
  t1 = time.time()

  model_path = common.resolve_xgboostmodel_path(tag)
  model.save_model(model_path)
  preds = model.predict(dval)
  _, fscore = quick_fscore(preds, None)
  logging.info('Final validation (quick) fscore = {}'.format(fscore))
  resultsdict = dict(fscore=fscore, evals=evals_result, duration=t1-t0)
  # NOTE(review): assumes the 'results' directory already exists - confirm.
  res_path = os.path.join(common.XGBOOST_DIR, 'results', tag+'.pickle')
  # Pickle data is bytes: binary mode is required on Python 3 and avoids
  # newline translation on Windows.
  with open(res_path, 'wb') as f:
    pickle.dump(resultsdict, f)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号