train.py source code

python

Project: instacart-basket-prediction  Author: colinmorris
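import argparse
import logging

import joblib
import pandas as pd
import sklearn.svm
from sklearn.linear_model import LogisticRegression

# baskets, time_me, vectorize_fold and METAVECTORS_PICKLEPATH come from the
# project's own modules; their imports are not shown in this excerpt.

def main():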
  baskets.time_me.set_default_mode('print')
  logging.basicConfig(level=logging.INFO)
  parser = argparse.ArgumentParser()
  parser.add_argument('tags', nargs='+')
  parser.add_argument('-f', '--train-fold', default='train')
  parser.add_argument('--validation-fold', help='Fold for validation (default: None)')
  parser.add_argument('--no-metafeats', action='store_true')
  parser.add_argument('--svm', action='store_true')
  args = parser.parse_args()

  with time_me("Loaded metavectors"):
    meta_df = pd.read_pickle(METAVECTORS_PICKLEPATH)

  with time_me("Made training vectors"):
    X, y = vectorize_fold(args.train_fold, args.tags, meta_df, use_metafeats=not args.no_metafeats)

  # This model switch is clumsy.
  if args.svm:
    # Kernel SVC is slow (the sklearn docs say it is hard to scale to a
    # dataset with more than roughly 20k examples), so LinearSVC is used.
    #model = sklearn.svm.SVC(verbose=True, probability=True, C=1.0)
    model = sklearn.svm.LinearSVC(penalty='l2', loss='hinge', C=.001, verbose=1)
  else:
    # TODO: tune the regularization strength C
    model = LogisticRegression(verbose=1)
  with time_me('Trained model', mode='print'):
    model.fit(X, y)

  model_fname = 'model.pkl'
  joblib.dump(model, model_fname)
  # TODO: report accuracy on the validation set (--validation-fold is parsed but not yet used)
  return model
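
The closing TODO in main() leaves validation accuracy unreported, and the --validation-fold argument is parsed but never used. One possible way to fill that gap is sketched below. It is not the author's code: report_accuracy and ThresholdStub are hypothetical names introduced only for illustration, and the helper assumes nothing more than a model with a scikit-learn style predict(X) method, which both LinearSVC and LogisticRegression provide.

```python
import logging


def report_accuracy(model, X_val, y_val, log=logging.getLogger(__name__)):
    """Log and return the fraction of validation examples predicted correctly.

    Hypothetical helper: `model` is any object with a predict(X) method.
    """
    predictions = list(model.predict(X_val))
    labels = list(y_val)
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels differ in length")
    if not labels:
        raise ValueError("validation set is empty")
    # Count exact matches between predicted and true labels.
    correct = sum(1 for p, t in zip(predictions, labels) if p == t)
    accuracy = correct / len(labels)
    log.info("Validation accuracy: %.4f (%d/%d)", accuracy, correct, len(labels))
    return accuracy


class ThresholdStub:
    """Hypothetical stand-in for a fitted model, used only to exercise the helper."""

    def predict(self, X):
        # Predict class 1 when the first feature exceeds 0.5, else class 0.
        return [1 if row[0] > 0.5 else 0 for row in X]


# Tiny made-up validation set: the stub gets three of the four labels right.
X_val = [[0.9], [0.2], [0.7], [0.1]]
y_val = [1, 0, 0, 0]
acc = report_accuracy(ThresholdStub(), X_val, y_val)
```

Inside main(), the same helper could presumably be called after model.fit(X, y) whenever args.validation_fold is set, with the validation vectors built by a second vectorize_fold call on that fold using the same tags and metafeature setting as the training call.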