test_backends.py source code

python

Project: keras-customized  Author: ambrite
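import numpy as np
from numpy.testing import assert_allclose
from keras.backend import theano_backend as KTH
from keras.backend import tensorflow_backend as KTF


def check_composed_tensor_operations(first_function_name, first_function_args,
                                     second_function_name, second_function_args,
                                     input_shape):
    '''Creates a random tensor t0 with shape input_shape and computes
           t1 = first_function_name(t0, **first_function_args)
           t2 = second_function_name(t1, **second_function_args)
    with both the Theano and TensorFlow backends, then checks that the results match.
    '''
    # Random values centered on zero, in [-0.5, 0.5).
    val = np.random.random(input_shape) - 0.5
    xth = KTH.variable(val)
    xtf = KTF.variable(val)

    # Apply the first operation, looked up by name, in each backend.
    yth = getattr(KTH, first_function_name)(xth, **first_function_args)
    ytf = getattr(KTF, first_function_name)(xtf, **first_function_args)

    # Apply the second operation to the intermediate result and evaluate it.
    zth = KTH.eval(getattr(KTH, second_function_name)(yth, **second_function_args))
    ztf = KTF.eval(getattr(KTF, second_function_name)(ytf, **second_function_args))

    # Both backends must agree on the output shape and, within atol, on the values.
    assert zth.shape == ztf.shape
    assert_allclose(zth, ztf, atol=1e-05)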
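The same "compose two named operations, run them in two backends, compare" pattern can be tried without Theano or TensorFlow installed. The sketch below is only an illustration under stated assumptions: `Float64Backend`, `Float32Backend` and `check_composed_operations` are hypothetical stand-ins built on NumPy (they are not Keras backends), and the second backend differs from the first only in working precision, which is why the comparison needs a tolerance.

```python
import numpy as np
from numpy.testing import assert_allclose


class Float64Backend:
    """Hypothetical reference backend: plain NumPy in float64 (not a Keras backend)."""
    dtype = np.float64

    @classmethod
    def variable(cls, value):
        # Wrap the input as an array in this backend's precision.
        return np.asarray(value, dtype=cls.dtype)

    @staticmethod
    def eval(x):
        # NumPy is eager, so "evaluating" just returns the array.
        return np.asarray(x)

    @staticmethod
    def reshape(x, shape):
        return np.reshape(x, shape)

    @staticmethod
    def sum(x, axis=None):
        return np.sum(x, axis=axis)


class Float32Backend(Float64Backend):
    """Hypothetical second backend: same operations, float32 precision."""
    dtype = np.float32


def check_composed_operations(backend_a, backend_b,
                              first_name, first_args,
                              second_name, second_args,
                              input_shape, atol=1e-5):
    """Apply first_name, then second_name, in both backends and compare the results."""
    val = np.random.random(input_shape) - 0.5
    results = []
    for backend in (backend_a, backend_b):
        x = backend.variable(val)
        # Look each operation up by name, as the Keras test does with getattr().
        y = getattr(backend, first_name)(x, **first_args)
        z = backend.eval(getattr(backend, second_name)(y, **second_args))
        results.append(z)
    za, zb = results
    assert za.shape == zb.shape
    assert_allclose(za, zb, atol=atol)
    return za, zb


if __name__ == "__main__":
    np.random.seed(0)
    za, zb = check_composed_operations(Float64Backend, Float32Backend,
                                       'reshape', {'shape': (4, 3)},
                                       'sum', {'axis': 1},
                                       input_shape=(2, 6))
    print(za.shape)  # (4,)
```

Passing the two backends as parameters, rather than hard-coding them as the original does with `KTH` and `KTF`, keeps the helper reusable if a third implementation is added later.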