def main():
    """Command-line entry point: load (or create) hyperparameters for a tag, then train.

    Parses one positional argument, ``tag``, which selects the hyperparameter
    set via ``hypers.hps_for_tag``. If no hyperparameters are defined for the
    tag (``hypers.NoHpsDefinedException``):

    1. Defaults come from ``hypers.get_default_hparams()``.
    2. The deprecated ``--train-recordfile``, ``--n-rounds`` and ``--weight``
       flags override those defaults.
    3. The result is saved under the tag with ``hypers.save_hps``.

    The hyperparameters are then passed to ``validate_hps``. A ``Dataset`` is
    built from ``hps.train_file``, and ``train`` runs inside a
    ``time_me(mode='stderr')`` context. That context presumably reports
    elapsed time on stderr; confirm against ``time_me``.
    """
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('tag')
    parser.add_argument('--train-recordfile', default='train',
                        help='identifier for file with the users to train on (default: train). deprecated: specify in hps...')
    parser.add_argument('-n', '--n-rounds', type=int, default=50,
                        help='Number of rounds of boosting. Deprecated: specify this in hp config file')
    parser.add_argument('--weight', action='store_true',
                        help='Whether to do per-instance weighting. Deprecated: specify in hps')
    args = parser.parse_args()
    try:
        hps = hypers.hps_for_tag(args.tag)
    except hypers.NoHpsDefinedException:
        # logging.warn is a deprecated alias of logging.warning; lazy %-args
        # defer formatting to the logging framework.
        logging.warning('No hps found for tag %s. Creating and saving some.', args.tag)
        hps = hypers.get_default_hparams()
        # The deprecated flags are only consulted when a fresh hps set is
        # created for this tag.
        hps.train_file = args.train_recordfile
        hps.rounds = args.n_rounds
        hps.weight = args.weight
        hypers.save_hps(args.tag, hps)
    else:
        # Saved hps take precedence over the deprecated flags. Warn rather
        # than silently ignore flags the user explicitly changed.
        ignored = [
            dest for dest in ('train_recordfile', 'n_rounds', 'weight')
            if getattr(args, dest) != parser.get_default(dest)
        ]
        if ignored:
            logging.warning(
                'hps already exist for tag %s; ignoring deprecated flag(s): %s',
                args.tag, ', '.join(ignored))
    validate_hps(hps)
    dataset = Dataset(hps.train_file, hps)
    with time_me(mode='stderr'):
        train(dataset, args.tag, hps)
评论列表
文章目录