def test_get_init_cell(get_init_cell):
    """Check the cell and initial state returned by ``get_init_cell``.

    ``get_init_cell`` is called with an int32 placeholder as the batch size
    and an RNN size of 256, and must return a ``(cell, init_state)`` pair.

    The checks are:
      * ``cell`` is a ``tf.contrib.rnn.MultiRNNCell``;
      * ``init_state`` has a ``name`` attribute;
      * that name is exactly ``'initial_state:0'``.

    :param get_init_cell: Callable under test, invoked as
        ``get_init_cell(batch_size, rnn_size)``.
    :raises AssertionError: If any of the checks above fails.
    """
    # Build everything in a throwaway graph so the ops created by the
    # function under test do not land in the default graph.
    with tf.Graph().as_default():
        test_batch_size_ph = tf.placeholder(tf.int32)
        test_rnn_size = 256

        cell, init_state = get_init_cell(test_batch_size_ph, test_rnn_size)

        # Check type
        assert isinstance(cell, tf.contrib.rnn.MultiRNNCell), \
            'Cell is wrong type. Found {} type'.format(type(cell))

        # Check for name attribute
        assert hasattr(init_state, 'name'), \
            'Initial state doesn\'t have the "name" attribute. Try using `tf.identity` to set the name.'

        # Check name (the ':0' suffix is part of the expected tensor name)
        assert init_state.name == 'initial_state:0', \
            'Initial state doesn\'t have the correct name. Found the name {}'.format(init_state.name)

    # Only reached when every assertion above passed.
    _print_success_message()
problem_unittests.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录