test_keras_backend.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:python-alp 作者: tboquet 项目源码 文件源码
def test_experiment_fit_gen_async(self, get_model, get_loss_metric,
                                      get_custom_l):
        """Check ``Experiment.fit_gen_async`` with several validation inputs.

        The experiment is fitted asynchronously once for each validation
        input in ``[None, 1, data_val_use]``. The value ``1`` is a sentinel
        meaning "validate on a second generator built with ``make_gen``".
        After each run the worker thread is joined. Every metric whose key
        does not contain ``'iter'`` must then hold ``nb_epoch`` entries in
        ``expe.full_res['metrics']``.

        Args:
            get_model: fixture; called with ``get_custom_l`` to build the
                model passed to ``prepare_model``.
            get_loss_metric: fixture forwarded to ``prepare_model``.
            get_custom_l: fixture forwarded to ``get_model`` and
                ``prepare_model``.
        """
        new_session()
        model, metrics, cust_objects = prepare_model(get_model(get_custom_l),
                                                     get_loss_metric,
                                                     get_custom_l)

        _, data_val_use = make_data(train_samples, test_samples)
        expe = Experiment(model)

        nb_epoch = 2
        # The assertion below expects one recorded value per epoch
        # for each non-'iter' metric.
        expected_value = nb_epoch
        for val in [None, 1, data_val_use]:
            gen, data, data_stream = make_gen(batch_size)
            # Decide once whether a second generator is used for validation,
            # and keep that generator in its own name. Rebinding ``val``
            # would make a later ``val == 1`` check always False, and the
            # validation generator would then never be closed.
            use_val_gen = val == 1
            val_input = val
            if use_val_gen:
                val_input, data_2, data_stream_2 = make_gen(batch_size)

            try:
                _, thread = expe.fit_gen_async([gen], [val_input],
                                               nb_epoch=nb_epoch,
                                               model=model,
                                               metrics=metrics,
                                               custom_objects=cust_objects,
                                               samples_per_epoch=64,
                                               nb_val_samples=128,
                                               verbose=2, overwrite=True)

                # Wait for the background fit to finish before reading results.
                thread.join()

                for k in expe.full_res['metrics']:
                    if 'iter' not in k:
                        assert len(
                            expe.full_res['metrics'][k]) == expected_value
            finally:
                # Always release the generators, even if fitting or an
                # assertion fails, so later iterations/tests are not affected.
                close_gens(gen, data, data_stream)
                if use_val_gen:
                    close_gens(val_input, data_2, data_stream_2)

        if K.backend() == 'tensorflow':
            K.clear_session()

        # NOTE(review): presumably only here so ``self`` is referenced.
        print(self)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号