test_model.py source code

python

Project: ntm-one-shot    Author: tristandeleu
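import numpy as np
import pytest
import theano
import theano.tensor as T

# Module path assumed from the ntm-one-shot repository layout.
from mann.model import memory_augmented_neural_network


def test_shape():
    """Check the output shape and that mismatched input shapes are rejected."""
    input_var = T.tensor3('input')
    target_var = T.imatrix('target')
    output_var, _, _ = memory_augmented_neural_network(
        input_var, target_var,
        batch_size=16,
        nb_class=5,
        memory_shape=(128, 40),
        controller_size=200,
        input_size=20 * 20,
        nb_reads=4)

    posterior_fn = theano.function([input_var, target_var], output_var)

    # Cast to floatX so a float32 config rejects bad shapes, not bad dtypes.
    floatX = theano.config.floatX
    # Valid inputs: (batch_size, time_steps, input_size) and (batch_size, time_steps).
    test_input = np.random.rand(16, 50, 20 * 20).astype(floatX)
    test_target = np.random.randint(5, size=(16, 50)).astype('int32')
    # Invalid inputs: one extra sample in the batch, or one feature missing.
    test_input_invalid_batch_size = np.random.rand(16 + 1, 50, 20 * 20).astype(floatX)
    test_input_invalid_depth = np.random.rand(16, 50, 20 * 20 - 1).astype(floatX)
    test_output = posterior_fn(test_input, test_target)

    # One class posterior per (sample, time step).
    assert test_output.shape == (16, 50, 5)
    with pytest.raises(ValueError):
        posterior_fn(test_input_invalid_batch_size, test_target)
    with pytest.raises(ValueError):
        posterior_fn(test_input_invalid_depth, test_target)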
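The test pins down a shape contract. An input of shape (batch_size, time_steps, input_size) and an int32 target of shape (batch_size, time_steps) should produce posteriors of shape (batch_size, time_steps, nb_class), and a mismatch in batch size or input depth should raise ValueError. The sketch below restates that contract in plain NumPy so it runs without Theano. `dummy_posterior` and `check_shape_contract` are hypothetical helpers, not part of ntm-one-shot. The stand-in does not implement the memory-augmented network: it applies a random linear read-out and a softmax, and it only mirrors the validation the test expects.

```python
import numpy as np


def dummy_posterior(inputs, targets, batch_size=16, nb_class=5,
                    input_size=20 * 20, seed=0):
    """Hypothetical stand-in for posterior_fn that only enforces shapes.

    inputs:  float array of shape (batch_size, time_steps, input_size)
    targets: int array of shape (batch_size, time_steps)
    returns: class posteriors of shape (batch_size, time_steps, nb_class)
    """
    if inputs.ndim != 3:
        raise ValueError("inputs must be 3-D")
    if inputs.shape[0] != batch_size:
        raise ValueError("expected batch size %d, got %d"
                         % (batch_size, inputs.shape[0]))
    if inputs.shape[2] != input_size:
        raise ValueError("expected input depth %d, got %d"
                         % (input_size, inputs.shape[2]))
    if targets.shape != inputs.shape[:2]:
        raise ValueError("targets must have shape (batch_size, time_steps)")

    # Random linear read-out followed by a softmax over the class axis.
    rng = np.random.RandomState(seed)
    weights = rng.randn(input_size, nb_class) * 0.01
    logits = inputs @ weights                     # (batch, time, nb_class)
    logits -= logits.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)


def check_shape_contract(fn):
    """Run the same three checks as test_shape against a callable fn."""
    inputs = np.random.rand(16, 50, 20 * 20)
    targets = np.random.randint(5, size=(16, 50)).astype('int32')
    assert fn(inputs, targets).shape == (16, 50, 5)
    # One extra sample in the batch, then one feature missing per step.
    for bad in (np.random.rand(17, 50, 400), np.random.rand(16, 50, 399)):
        try:
            fn(bad, targets)
        except ValueError:
            pass
        else:
            raise AssertionError("shape mismatch was not rejected")
    return True


if __name__ == "__main__":
    print(check_shape_contract(dummy_posterior))  # prints True
```

Because the checks take a generic callable, the same three assertions can be reused against any function with the `(inputs, targets)` signature.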