basic_rnn_test.py source code

Project: sonnet    Author: deepmind
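```python
from unittest import mock

import sonnet as snt  # Sonnet 1.x, TF1-style graph API
import tensorflow as tf


class DeepRNNTest(tf.test.TestCase):

  def testNoSizeButAlreadyConnected(self):
    batch_size = 16
    # tf.nn.relu is a plain function, so it exposes no `output_size`.
    cores = [snt.LSTM(hidden_size=10), snt.Linear(output_size=42), tf.nn.relu]
    rnn = snt.DeepRNN(cores, skip_connections=False)
    # Connect the module into the graph once so its output shape is known.
    unused_output = rnn(tf.zeros((batch_size, 128)),
                        rnn.initial_state(batch_size=batch_size))

    with mock.patch.object(tf.logging, "warning") as mocked_logging_warning:
      output_size = rnn.output_size
      # Correct size is automatically inferred from the connected output...
      self.assertEqual(output_size, tf.TensorShape([42]))
      # ...and a warning is logged saying so.
      self.assertTrue(mocked_logging_warning.called)
      first_call_args = mocked_logging_warning.call_args[0]
      self.assertIn("DeepRNN has been connected into the graph, "
                    "so inferred output size", first_call_args[0])


if __name__ == "__main__":
  tf.test.main()
```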
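As the test name and the expected warning text suggest, `DeepRNN` cannot read an output size from a core like `tf.nn.relu`, so once it has been connected it falls back to the width it actually produced and logs a warning. The dependency-free sketch below illustrates that pattern and the `mock.patch.object` technique used to assert on the warning. `StackedCores`, `Dense`, `relu` and `check_inferred_size` are hypothetical names invented for illustration; this is not Sonnet's implementation, and plain Python lists stand in for tensors.

```python
import logging
from unittest import mock


class Dense:
  """Hypothetical core that declares its output size (always emits zeros)."""

  def __init__(self, output_size):
    self.output_size = output_size

  def __call__(self, x):
    return [0.0] * self.output_size


def relu(x):
  """Plain elementwise function: like tf.nn.relu, it has no `output_size`."""
  return [max(0.0, v) for v in x]


class StackedCores:
  """Hypothetical DeepRNN-like container; NOT Sonnet's implementation."""

  def __init__(self, cores):
    self._cores = cores
    self._connected_output_size = None  # unknown until the first call

  def __call__(self, x):
    for core in self._cores:
      x = core(x)
    # Record the width actually produced, as a graph-connected module would.
    self._connected_output_size = len(x)
    return x

  @property
  def output_size(self):
    # Preferred path: every core declares a size, so trust the last one.
    if all(hasattr(core, "output_size") for core in self._cores):
      return self._cores[-1].output_size
    # Fallback: infer from the connected output, and say so in a warning.
    if self._connected_output_size is not None:
      logging.warning("StackedCores has been connected, so inferred output "
                      "size as %d", self._connected_output_size)
      return self._connected_output_size
    raise ValueError("Output size is unknown until the stack is connected.")


def check_inferred_size():
  """Connects a stack once, then reads output_size with logging patched."""
  stack = StackedCores([Dense(10), Dense(42), relu])
  stack([1.0] * 128)  # connect once so the output width is known
  with mock.patch.object(logging, "warning") as mocked_warning:
    size = stack.output_size
  message = mocked_warning.call_args[0][0]  # first positional argument
  return size, mocked_warning.called, message
```

Here `check_inferred_size()` returns a size of 42 and reports that the patched `logging.warning` was called. Because `output_size` looks up `logging.warning` at call time, patching the module attribute is enough to intercept it, which mirrors how the Sonnet test patches `tf.logging.warning`.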