test_networks.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:third_person_im 作者: bstadie 项目源码 文件源码
def test_gru_network():
    from rllab.core.network import GRUNetwork
    import lasagne.layers as L
    from rllab.misc import ext
    import numpy as np
    network = GRUNetwork(
        input_shape=(2, 3),
        output_dim=5,
        hidden_dim=4,
    )
    f_output = ext.compile_function(
        inputs=[network.input_layer.input_var],
        outputs=L.get_output(network.output_layer)
    )
    assert f_output(np.zeros((6, 8, 2, 3))).shape == (6, 8, 5)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号