def main():
    """Train a classifier on one fold, pickle it, and optionally validate it.

    Command-line arguments:
      tags               One or more tags, passed through to ``vectorize_fold``.
      -f/--train-fold    Fold to train on (default: 'train').
      --validation-fold  Optional fold to evaluate the trained model on; when
                         given, its mean accuracy (``model.score``) is logged.
      --no-metafeats     Build feature vectors without metafeatures.
      --svm              Use ``LinearSVC`` instead of ``LogisticRegression``.
      --model-path       Where to write the fitted model (default: 'model.pkl').

    Returns:
        The fitted estimator (also written to ``--model-path`` with joblib).
    """
    baskets.time_me.set_default_mode('print')
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('tags', nargs='+')
    parser.add_argument('-f', '--train-fold', default='train')
    parser.add_argument('--validation-fold', help='Fold for validation (default: None)')
    parser.add_argument('--no-metafeats', action='store_true')
    parser.add_argument('--svm', action='store_true')
    parser.add_argument('--model-path', default='model.pkl',
                        help='Output path for the pickled model (default: model.pkl)')
    args = parser.parse_args()
    use_metafeats = not args.no_metafeats

    with time_me("Loaded metavectors"):
        meta_df = pd.read_pickle(METAVECTORS_PICKLEPATH)
    with time_me("Made training vectors"):
        X, y = vectorize_fold(args.train_fold, args.tags, meta_df, use_metafeats=use_metafeats)

    # Model choice is a plain flag switch; hyperparameters (e.g. C) are not
    # exposed on the command line yet.
    if args.svm:
        # A kernel SVC (sklearn.svm.SVC(probability=True)) was too slow here;
        # sklearn documents SVC fit time as scaling at least quadratically with
        # the number of samples. NOTE: unlike SVC(probability=True), LinearSVC
        # provides no predict_proba.
        model = sklearn.svm.LinearSVC(penalty='l2', loss='hinge', C=.001, verbose=1)
    else:
        # TODO: tune C (sklearn's default is used here).
        model = LogisticRegression(verbose=1)
    with time_me('Trained model', mode='print'):
        model.fit(X, y)
    # Persist before validating so a failure during evaluation does not lose
    # the trained model.
    joblib.dump(model, args.model_path)

    # Report accuracy on the validation fold, if one was requested. For sklearn
    # classifiers, score() returns mean accuracy.
    if args.validation_fold is not None:
        with time_me("Made validation vectors"):
            X_val, y_val = vectorize_fold(
                args.validation_fold, args.tags, meta_df, use_metafeats=use_metafeats)
        acc = model.score(X_val, y_val)
        logging.info("Validation accuracy on fold %r: %.4f", args.validation_fold, acc)
    return model
评论列表
文章目录