test_keras_backend.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:python-alp 作者: tboquet 项目源码 文件源码
def test_build_predict_func(self, get_model):
        """Test the build of a model"""
        new_session()
        X_tr = np.ones((train_samples, input_dim))
        model = get_model()
        model.compile(loss='categorical_crossentropy',
                      optimizer='rmsprop',
                      metrics=['accuracy'])

        model_name = model.__class__.__name__

        pred_func = KTB.build_predict_func(model)

        tensors = [X_tr]
        if model_name != 'Model':
            tensors.append(1.)

        res = pred_func(tensors)

        assert len(res[0]) == len(X_tr)

        if K.backend() == 'tensorflow':
            K.clear_session()

        print(self)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号