test_pipeline.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:human-rl 作者: gsastry 项目源码 文件源码
def train_classifier(test, blocker=False):
    """Train a Pong catastrophe classifier (or blocker) and report metrics.

    Loads episodes from the module-level ``input_path``, shuffles them with a
    fixed seed, and splits them into train/valid/test sets via
    ``DataLoader.split_episodes``. It then trains a convolutional model with
    ``run_experiments`` (logging to ``save_classifier_path``) and finally
    calls ``run_classifier_metrics()``.

    Args:
        test: If truthy, use a tiny configuration (2 episodes per split,
            50 steps, batch size 20, 2 conv layers) for a quick smoke run.
        blocker: If True, label with ``PongBlockerLabeller`` and enable
            ``use_action`` (with ``expected_positive_weight=0.05``);
            otherwise label with ``PongClassifierLabeller`` and disable
            ``use_action``.
    """
    if test:
        number_train = 2
        number_valid = 2
        number_test = 2
        steps = 50
        batch_size = 20
        conv_layers = 2
    else:
        number_train = 20
        number_valid = 30
        number_test = 25
        steps = 1000
        batch_size = 1024
        conv_layers = 3

    # Per the multiprocessing docs this only has an effect in frozen Windows
    # executables; it is harmless elsewhere.
    multiprocessing.freeze_support()

    episode_paths = frame.episode_paths(input_path)
    print('Found {} episodes'.format(len(episode_paths)))
    # Fixed seed so the episode order is reproducible across runs
    # (presumably making the train/valid/test split reproducible too).
    np.random.seed(seed=42)
    np.random.shuffle(episode_paths)

    if blocker:
        common_hparams = dict(use_action=True, expected_positive_weight=0.05)
        labeller = humanrl.pong_catastrophe.PongBlockerLabeller()
    else:
        common_hparams = dict(use_action=False)
        labeller = humanrl.pong_catastrophe.PongClassifierLabeller()

    data_loader = DataLoader(labeller, TensorflowClassifierHparams(**common_hparams))
    datasets = data_loader.split_episodes(
        episode_paths, number_train, number_valid, number_test, use_all=False)

    hparams_list = [
        dict(
            # Crop rows 34..194 and the full 160-pixel width of each frame.
            image_crop_region=((34, 34 + 160), (0, 160)),
            convolution2d_stack_args=[(4, [3, 3], [2, 2])] * conv_layers,
            batch_size=batch_size,
            multiprocess=False,
            fully_connected_stack_args=[50, 10],
            use_observation=False,
            use_image=True,
            verbose=True,
        )
    ]

    start_experiment = time.time()
    print('Run experiment params: ', dict(number_train=number_train, number_valid=number_valid,
                                          number_test=number_test, steps=steps, batch_size=batch_size,
                                          conv_layers=conv_layers))
    print('hparams', common_hparams, hparams_list[0])

    logdir = save_classifier_path
    # int(0.1 * steps) is 0 when steps < 10; clamp to 1 so the logging
    # interval can never be zero (e.g. if it is used as a modulus).
    log_every = max(1, int(.1 * steps))
    run_experiments(
        logdir, data_loader, datasets, common_hparams, hparams_list,
        steps=steps, log_every=log_every)

    time_experiment = time.time() - start_experiment
    # Float division: under Python 2, `1/60` evaluates to 0, which would
    # always report 0 minutes.
    print('Steps: {}. Time in mins: {}'.format(steps, time_experiment / 60.0))

    run_classifier_metrics()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号