test_core.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:keras-customized 作者: ambrite 项目源码 文件源码
def test_lambda():
    """Exercise the Lambda layer and round-trip it through its config.

    Two ``layer_test`` runs cover a one-argument lambda and a lambda that
    receives extra keyword ``arguments``.  After that, three Lambda layers
    are rebuilt from their ``get_config()`` output.  Nothing is asserted
    about the rebuilt layers, so those steps only check that serialization
    and deserialization complete without raising.
    """
    from keras.utils.layer_utils import layer_from_config

    def rebuild(layer):
        # Round-trip through the generic, class-name based deserializer.
        return layer_from_config({'class_name': 'Lambda',
                                  'config': layer.get_config()})

    lambda_cls = core.Lambda

    # Plain single-argument function.
    increment_kwargs = {'function': lambda x: x + 1}
    layer_test(lambda_cls, kwargs=increment_kwargs, input_shape=(3, 2))

    # Function taking extra arguments supplied through 'arguments'.
    affine_kwargs = {
        'function': lambda x, a, b: x * a + b,
        'arguments': {'a': 0.6, 'b': 0.4},
    }
    layer_test(lambda_cls, kwargs=affine_kwargs, input_shape=(3, 2))

    # test serialization with function
    def f(x):
        return x + 1

    restored = rebuild(lambda_cls(f))

    # Serialization with inline lambdas for both the function and the
    # output shape; this case goes through Lambda.from_config directly.
    inline_layer = lambda_cls(
        lambda x: K.concatenate([K.square(x), x]),
        output_shape=lambda s: tuple(list(s)[:-1] + [2 * s[-1]]))
    restored = lambda_cls.from_config(inline_layer.get_config())

    # test serialization with output_shape function
    def f(x):
        return K.concatenate([K.square(x), x])

    def f_shape(s):
        return tuple(list(s)[:-1] + [2 * s[-1]])

    restored = rebuild(lambda_cls(f, output_shape=f_shape))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号