precompute_probs.py source code

python

Project: instacart-basket-prediction  Author: colinmorris
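```python
import argparse
import logging
import xgboost as xgb

# hypers, Dataset, Mode, common, time_me and get_pdict come from elsewhere in
# precompute_probs.py or from the project's own modules; those definitions
# are not part of this excerpt.

def main():
  logging.basicConfig(level=logging.INFO)
  parser = argparse.ArgumentParser()
  # One or more model tags to compute probabilities for.
  parser.add_argument('tags', metavar='tag', nargs='+')
  parser.add_argument('--fold', default='test',
      help='identifier for file with the users to test on (default: test)')
  args = parser.parse_args()

  for model_tag in args.tags:
    # Hyperparameters registered for this tag, and the users in the chosen fold.
    hps = hypers.hps_for_tag(model_tag)
    dataset = Dataset(args.fold, hps, mode=Mode.inference)
    # Load the trained XGBoost model saved under this tag.
    path = common.resolve_xgboostmodel_path(model_tag)
    logging.info('Loading model with tag {}'.format(model_tag))
    model = xgb.Booster(model_file=path)
    logging.info('Computing probs for tag {}'.format(model_tag))
    with time_me('Computed probs for {}'.format(model_tag), mode='stderr'):
      # pdict holds the predicted probabilities, one entry per user.
      pdict = get_pdict(model, dataset)
      logging.info('Got probs for {} users'.format(len(pdict)))
      # TODO: might want to enforce some namespace separation between
      # rnn-generated pdicts and ones coming from xgboost models?
      common.save_pdict_for_tag(model_tag, pdict, args.fold)

if __name__ == '__main__':
  main()
```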
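The command-line interface of main() can be exercised on its own by rebuilding just the parser and passing explicit argument lists to parse_args. This is a minimal sketch: build_parser is a hypothetical helper, and the tag names xgb_a and xgb_b and the fold name validation are illustrative values, not ones taken from the project.

```python
import argparse


def build_parser():
    """Hypothetical helper: the same CLI that main() builds, isolated for inspection."""
    parser = argparse.ArgumentParser()
    # One or more model tags; nargs='+' makes argparse exit with an error if none are given.
    parser.add_argument('tags', metavar='tag', nargs='+')
    # Which user split to score; falls back to 'test' when the flag is omitted.
    parser.add_argument('--fold', default='test',
                        help='identifier for file with the users to test on (default: test)')
    return parser


if __name__ == '__main__':
    parser = build_parser()
    # Equivalent to: python precompute_probs.py xgb_a xgb_b --fold validation
    args = parser.parse_args(['xgb_a', 'xgb_b', '--fold', 'validation'])
    print(args.tags, args.fold)  # ['xgb_a', 'xgb_b'] validation
    # Equivalent to: python precompute_probs.py xgb_a
    args = parser.parse_args(['xgb_a'])
    print(args.tags, args.fold)  # ['xgb_a'] test
```

Because tags uses nargs='+', running the script with no tag makes argparse print a usage error and exit, while --fold can be left out and defaults to test.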
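The time_me helper used in main() is not part of this excerpt. The sketch below is a hypothetical stand-in, written with contextlib.contextmanager, that reports the wall-clock time spent in the wrapped block. Treating mode='stderr' as "write to standard error" and any other value as "write to standard output" is an assumption, not the project's documented behaviour.

```python
import sys
import time
from contextlib import contextmanager


@contextmanager
def time_me(msg, mode='stderr'):
    """Hypothetical stand-in for the project's time_me: print msg with elapsed seconds."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        line = '{} in {:.2f}s'.format(msg, elapsed)
        # Assumed meaning of mode: 'stderr' -> standard error, anything else -> standard output.
        if mode == 'stderr':
            print(line, file=sys.stderr)
        else:
            print(line)


if __name__ == '__main__':
    with time_me('Computed probs for xgb_a', mode='stderr'):
        total = sum(range(1_000_000))  # placeholder workload
```

Reporting from the finally clause means the elapsed time is printed even if the wrapped block raises, which is one reasonable design choice for a timing helper.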