def test_build_nn(build_nn):
    """Check a ``build_nn`` implementation against a small stacked-LSTM setup.

    A fresh graph is created, a two-layer LSTM cell and an int32 input
    placeholder are built, and ``build_nn`` is called as
    ``build_nn(cell, rnn_size, input_data, vocab_size)``. It must return a
    ``(logits, final_state)`` pair.

    The checks performed are:
      * ``final_state`` has a ``name`` attribute equal to ``'final_state:0'``;
      * ``logits`` has static shape ``input shape + [vocab_size]``;
      * ``final_state`` has static shape
        ``[num_layers, 2, None, rnn_size]``.

    :param build_nn: the function under test.
    :raises AssertionError: if any of the checks above fails.
    """
    with tf.Graph().as_default():
        test_input_data_shape = [128, 5]
        test_input_data = tf.placeholder(tf.int32, test_input_data_shape)
        test_rnn_size = 256
        test_rnn_layer_size = 2
        test_vocab_size = 27

        # Build a distinct cell object for every layer. ``[cell] * n`` would
        # repeat a single instance, so all layers would be backed by the same
        # cell; depending on the TensorFlow 1.x release that either raises a
        # variable-scope reuse error or ties the layers' weights together.
        test_cell = rnn.MultiRNNCell(
            [rnn.BasicLSTMCell(test_rnn_size) for _ in range(test_rnn_layer_size)])

        logits, final_state = build_nn(test_cell, test_rnn_size, test_input_data, test_vocab_size)

        # Check name
        assert hasattr(final_state, 'name'), \
            'Final state doesn\'t have the "name" attribute. Are you using build_rnn?'
        assert final_state.name == 'final_state:0', \
            'Final state doesn\'t have the correct name. Found the name {}. Are you using build_rnn?'.format(final_state.name)

        # Check Shape
        assert logits.get_shape().as_list() == test_input_data_shape + [test_vocab_size], \
            'Outputs has wrong shape. Found shape {}'.format(logits.get_shape())
        assert final_state.get_shape().as_list() == [test_rnn_layer_size, 2, None, test_rnn_size], \
            'Final state wrong shape. Found shape {}'.format(final_state.get_shape())

    # Only reached when every assertion above passed.
    _print_success_message()
problem_unittests.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录