def _int_option(args, key, default):
    """Return option ``key`` from ``args`` converted to ``int``, or ``default``.

    A missing key and a key whose value is ``None`` are both treated as
    "not given", so an options mapping that stores ``None`` for omitted flags
    does not trigger ``int(None)`` -> ``TypeError``.
    """
    value = args.get(key)
    return default if value is None else int(value)


def load_and_run(args, trainerClass):
    """Build a trainer from command-line options and train it one or more times.

    Args:
        args: mapping from option names (e.g. ``'--seed'``) to their values.
            Values may be ``None`` for options that were not supplied.
            NOTE(review): presumably a docopt-style dict -- confirm with caller.
        trainerClass: trainer class handed to ``load_trainer``.

    Recognised options:
        --seed               integer seed passed to ``load_trainer`` (default 0)
        --train-batch        training batch name (default "train")
        --validation-batch   optional validation batch name
        --test-batch         optional test batch name
        --print-params       truthy to print parameters after training
        --print-loss-breakdown  truthy to print the final loss breakdown
        --num-restarts       how many times ``trainer.train`` is run (default 1)
        --store-data         if set, results of each restart are written to
                             this location via ``store_results_to_hdf5``

    Prints the total wall-clock time, measured from before the trainer is
    loaded, once all restarts are done.
    """
    start_time = time.time()

    seed = _int_option(args, '--seed', 0)
    trainer = load_trainer(args, trainerClass, Data, seed)

    train_batch_name = args.get('--train-batch') or "train"
    validation_batch_name = args.get('--validation-batch')
    test_batch_name = args.get('--test-batch')
    print_params = args.get('--print-params') or False
    print_loss_breakdown = args.get('--print-loss-breakdown') or False
    num_restarts = _int_option(args, '--num-restarts', 1)
    # None when results should not be persisted; looked up once, not per restart.
    store_path = args.get('--store-data')

    # `range` (not the Python-2-only `xrange`) keeps this runnable on Python 3.
    for restart_idx in range(num_restarts):
        # Results are not used here; each restart's output is persisted below
        # when --store-data is given.
        _params, _discretized_params = trainer.train(
            train_batch_name,
            validation_batch_name=validation_batch_name,
            test_batch_name=test_batch_name,
            print_params=print_params,
            print_final_loss_breakdown=print_loss_breakdown)
        if store_path is not None:
            store_results_to_hdf5(store_path, trainer, train_batch_name,
                                  restart_idx=restart_idx)

    # "%.2f" = two decimal places; the previous "%2.f" meant width 2 with
    # zero decimals, which truncated the timing to whole seconds.
    print("Training stopped after %.2fs." % (time.time() - start_time))
评论列表
文章目录