train.py source code (Python)

Project: instacart-basket-prediction · Author: colinmorris
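```python
import argparse
import logging

# hypers, validate_hps, Dataset, time_me and train are defined elsewhere in the
# instacart-basket-prediction project; their imports are not part of this excerpt.


def main():
  logging.basicConfig(level=logging.INFO)
  parser = argparse.ArgumentParser()
  # The tag names the hyperparameter config (hps) and is also passed to train().
  parser.add_argument('tag')
  # The three flags below are only consulted when no hps exist yet for `tag`;
  # otherwise the saved hps are used and these values are ignored.
  parser.add_argument('--train-recordfile', default='train',
      help='Identifier of the record file with the users to train on (default: train). '
           'Deprecated: specify in hps instead.')
  parser.add_argument('-n', '--n-rounds', type=int, default=50,
      help='Number of boosting rounds (default: 50). '
           'Deprecated: specify in the hps config file instead.')
  parser.add_argument('--weight', action='store_true',
      help='Whether to apply per-instance weighting. Deprecated: specify in hps instead.')
  args = parser.parse_args()

  try:
    # Reuse previously saved hyperparameters for this tag, if any.
    hps = hypers.hps_for_tag(args.tag)
  except hypers.NoHpsDefinedException:
    # First run for this tag: start from defaults, apply the CLI values, and persist them.
    # logging.warn is deprecated; logging.warning is the supported spelling.
    logging.warning('No hps found for tag %s. Creating and saving some.', args.tag)
    hps = hypers.get_default_hparams()
    hps.train_file = args.train_recordfile
    hps.rounds = args.n_rounds
    hps.weight = args.weight
    hypers.save_hps(args.tag, hps)
  validate_hps(hps)
  dataset = Dataset(hps.train_file, hps)
  # time_me presumably reports the training time on stderr (its definition is not shown).
  with time_me(mode='stderr'):
    train(dataset, args.tag, hps)
```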
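The call to `train(...)` is wrapped in `time_me(mode='stderr')`, whose definition is not included in this excerpt. As a hypothetical sketch, assuming `time_me` is a context manager that measures wall-clock time with `time.perf_counter` and prints it to standard error, it could look like this. The `label` parameter, the `'stdout'` mode, and the yielded `timing` dict are additions for illustration only.

```python
import sys
import time
from contextlib import contextmanager


@contextmanager
def time_me(mode='stderr', label='block'):
    """Hypothetical sketch: time the enclosed block and report the elapsed wall-clock time.

    Yields a dict whose 'elapsed' key (seconds) is filled in when the block exits.
    """
    # Validate before timing starts so a bad mode fails fast in __enter__.
    if mode not in ('stderr', 'stdout'):  # 'stdout' is an assumed extra mode
        raise ValueError('unknown mode: {!r}'.format(mode))
    stream = sys.stderr if mode == 'stderr' else sys.stdout
    timing = {}
    start = time.perf_counter()
    try:
        yield timing
    finally:
        # Runs even if the block raises, so a failed run still reports its duration.
        timing['elapsed'] = time.perf_counter() - start
        print('{} took {:.3f}s'.format(label, timing['elapsed']), file=stream)


if __name__ == '__main__':
    with time_me(mode='stderr', label='train') as t:
        sum(i * i for i in range(100000))  # stand-in for train(dataset, tag, hps)
    # The elapsed time is reported on stderr and also available as t['elapsed'].
```

Putting the report in a `finally` clause around `yield` means a training run that crashes still logs how long it ran before failing, while the exception itself propagates unchanged.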