def testNoSizeButAlreadyConnected(self):
    """Checks that `output_size` is inferred once the DeepRNN is connected.

    The DeepRNN is built with `skip_connections=False` from an LSTM, a Linear
    module with `output_size=42`, and the plain function `tf.nn.relu`
    (which presumably exposes no `output_size` of its own). After the module
    has been connected into the graph, reading `rnn.output_size` is expected
    to return `TensorShape([42])` and to log a warning saying the size was
    inferred from the connected graph.
    """
    batch_size = 16
    cores = [snt.LSTM(hidden_size=10), snt.Linear(output_size=42), tf.nn.relu]
    rnn = snt.DeepRNN(cores, skip_connections=False)
    # Connect the module into the graph. The output tensor is not needed;
    # only the side effect of building/connecting the module matters here.
    rnn(tf.zeros((batch_size, 128)),
        rnn.initial_state(batch_size=batch_size))
    with mock.patch.object(tf.logging, "warning") as mocked_logging_warning:
        output_size = rnn.output_size
        # Correct size is automatically inferred.
        self.assertEqual(output_size, tf.TensorShape([42]))
        self.assertTrue(mocked_logging_warning.called)
        # `call_args` describes the most recent call to the mock: index 0 is
        # the tuple of positional arguments, whose first entry is the
        # warning message.
        warning_message = mocked_logging_warning.call_args[0][0]
        self.assertIn("DeepRNN has been connected into the graph, "
                      "so inferred output size", warning_message)
# Comment list
# Table of contents