train.py source code

python

Project: TerpreT   Author: 51alg
import time


def load_and_run(args, trainerClass):
    """Build a trainer from parsed command-line args and run num_restarts training runs."""
    start_time = time.time()
    seed = int(args.get('--seed', 0))
    # load_trainer, Data and store_results_to_hdf5 are not shown here; they come
    # from elsewhere in the TerpreT project.
    trainer = load_trainer(args, trainerClass, Data, seed)
    train_batch_name = args.get('--train-batch', None) or "train"
    validation_batch_name = args.get('--validation-batch', None)
    test_batch_name = args.get('--test-batch', None)
    print_params = args.get('--print-params', False) or False
    print_loss_breakdown = args.get('--print-loss-breakdown', False) or False
    num_restarts = int(args.get('--num-restarts', 1))

    # Each restart calls trainer.train again on the same trainer object.
    # `range` (rather than Python 2's `xrange`) works on both Python 2 and 3.
    for i in range(num_restarts):
        (params, discretized_params) = trainer.train(train_batch_name,
                                                     validation_batch_name=validation_batch_name,
                                                     test_batch_name=test_batch_name,
                                                     print_params=print_params,
                                                     print_final_loss_breakdown=print_loss_breakdown)

        # Optionally persist this restart's results to an HDF5 file.
        if args.get('--store-data') is not None:
            store_results_to_hdf5(args['--store-data'], trainer, train_batch_name, restart_idx=i)
        # Elapsed time is measured from start_time, so it is cumulative across restarts.
        print("Training stopped after %.2fs." % (time.time() - start_time))
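To see the option defaults and the restart loop of `load_and_run` in action without the rest of TerpreT, here is a minimal sketch. It assumes `args` is a docopt-style dictionary, where options that were not given on the command line map to `None` and flags map to `False`. `StubTrainer`, `parse_run_options`, `run_restarts` and the lambda store function are hypothetical stand-ins introduced only for illustration; the real function builds its trainer with `load_trainer` and writes results with `store_results_to_hdf5`.

```python
import time


class StubTrainer(object):
    """Hypothetical stand-in for a TerpreT trainer; it only records its calls."""

    def __init__(self):
        self.calls = []

    def train(self, train_batch_name, validation_batch_name=None, test_batch_name=None,
              print_params=False, print_final_loss_breakdown=False):
        self.calls.append((train_batch_name, validation_batch_name, test_batch_name))
        return {}, {}  # (params, discretized_params)


def parse_run_options(args):
    """Apply the same defaults that load_and_run applies to a docopt-style dict."""
    return {
        'train': args.get('--train-batch', None) or "train",
        'validation': args.get('--validation-batch', None),
        'test': args.get('--test-batch', None),
        'print_params': args.get('--print-params', False) or False,
        'print_loss_breakdown': args.get('--print-loss-breakdown', False) or False,
        'num_restarts': int(args.get('--num-restarts', 1)),
    }


def run_restarts(args, trainer, store_fn):
    """Same restart loop as load_and_run, with the trainer and store function injected."""
    start_time = time.time()
    opts = parse_run_options(args)
    for i in range(opts['num_restarts']):
        trainer.train(opts['train'],
                      validation_batch_name=opts['validation'],
                      test_batch_name=opts['test'],
                      print_params=opts['print_params'],
                      print_final_loss_breakdown=opts['print_loss_breakdown'])
        if args.get('--store-data') is not None:
            store_fn(args['--store-data'], trainer, opts['train'], restart_idx=i)
        print("Training stopped after %.2fs." % (time.time() - start_time))


# An unset '--train-batch' (None) falls back to "train"; '--num-restarts' arrives as a string.
args = {'--train-batch': None, '--num-restarts': '3', '--store-data': 'results.hdf5'}
stored = []
trainer = StubTrainer()
run_restarts(args, trainer,
             lambda path, tr, batch, restart_idx: stored.append((path, batch, restart_idx)))
```

Passing the trainer and the store function in as arguments, instead of looking them up as module-level names, is what lets the loop run against stubs; each restart stores its results under its own `restart_idx`.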