def test_get_model(self):
    """Build the next-step inference model and sanity-check its shapes.

    Mainly a smoke test that ``ns.standard_nextstep_inference`` can be
    constructed without raising. It also checks the static shapes of the
    first initial-state tensor, the number and shape of the per-step logits,
    and that the first final-state tensor matches the first initial-state
    tensor in shape.
    """
    # TODO(pfcm) nice helpers for setting up/tearing down a graph & sess

    def static_shape(tensor):
        # Static (graph-construction-time) shape as a plain Python list.
        return tensor.get_shape().as_list()

    # Inputs appear to be time-major, [steps, batch, features]: the test
    # expects one logits tensor per entry of the first dimension.
    num_steps, batch_size, num_features = 50, 30, 10
    num_units = 32    # hidden size of the RNN cell
    num_outputs = 5   # width of each logits tensor

    with tf.Graph().as_default():
        inputs = tf.placeholder(
            tf.float32, [num_steps, batch_size, num_features])
        cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
        istate, logits, fstate = ns.standard_nextstep_inference(
            cell, inputs, num_outputs)

        start_shape = static_shape(istate[0])
        self.assertEqual(start_shape, [batch_size, num_units])
        self.assertEqual(len(logits), num_steps)
        self.assertEqual(static_shape(logits[0]), [batch_size, num_outputs])
        self.assertEqual(start_shape, static_shape(fstate[0]))
评论列表
文章目录