problem_unittests.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:deep-learning-nd 作者: RyanCCollins 项目源码 文件源码
def test_build_nn(build_nn):
    """Check a ``build_nn`` implementation against the expected graph outputs.

    ``build_nn`` is called as ``build_nn(cell, rnn_size, input_data,
    vocab_size)`` and must return a ``(logits, final_state)`` pair.
    This test asserts that:

    * ``final_state`` has a ``name`` attribute equal to ``'final_state:0'``;
    * ``logits`` has static shape ``input_data_shape + [vocab_size]``;
    * ``final_state`` has static shape ``[num_layers, 2, None, rnn_size]``.

    Everything is built inside a fresh ``tf.Graph`` so the ops and
    variables created here stay out of the default graph.

    :param build_nn: the function under test.
    :raises AssertionError: if any of the checks above fail.
    """
    with tf.Graph().as_default():
        test_input_data_shape = [128, 5]
        test_input_data = tf.placeholder(tf.int32, test_input_data_shape)
        test_rnn_size = 256
        test_rnn_layer_size = 2
        test_vocab_size = 27
        # Create one distinct LSTM cell per layer. ``[cell] * n`` would place
        # the *same* cell object in every layer, which TensorFlow 1.1+ rejects
        # with a variable-reuse ValueError when the stacked cell is run.
        test_cell = rnn.MultiRNNCell(
            [rnn.BasicLSTMCell(test_rnn_size) for _ in range(test_rnn_layer_size)])

        logits, final_state = build_nn(test_cell, test_rnn_size, test_input_data, test_vocab_size)

        # Check name: the final state must be exposed as a tensor named
        # 'final_state' (e.g. via build_rnn), so it can be fetched by name later.
        assert hasattr(final_state, 'name'), \
            'Final state doesn\'t have the "name" attribute.  Are you using build_rnn?'
        assert final_state.name == 'final_state:0', \
            'Final state doesn\'t have the correct name. Found the name {}. Are you using build_rnn?'.format(final_state.name)

        # Check Shape: logits add a vocab dimension to the input shape; the
        # final state is [layers, 2, batch (unknown), rnn_size] — the 2 is
        # presumably the (c, h) pair of each LSTM layer.
        assert logits.get_shape().as_list() == test_input_data_shape + [test_vocab_size], \
            'Outputs has wrong shape.  Found shape {}'.format(logits.get_shape())
        assert final_state.get_shape().as_list() == [test_rnn_layer_size, 2, None, test_rnn_size], \
            'Final state wrong shape.  Found shape {}'.format(final_state.get_shape())

    _print_success_message()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号