basic_rnn_test.py source code

Project: sonnet  Author: deepmind
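```python
# Requires TensorFlow 1.x and Sonnet 1.x (snt.DeepRNN, tf.nn.dynamic_rnn).
import numpy as np
import sonnet as snt
import tensorflow as tf


# Enclosing test class added so the method runs standalone; the class name is illustrative.
class DeepRNNTest(tf.test.TestCase):

    def testMLPFinalCore(self):
        batch_size = 2
        sequence_length = 3
        input_size = 4
        mlp_last_layer_size = 17
        # An LSTM followed by a feed-forward MLP. With skip_connections=False the
        # MLP only receives the LSTM output, and its last layer sets the output depth.
        cores = [
            snt.LSTM(hidden_size=10),
            snt.nets.MLP(output_sizes=[6, 7, mlp_last_layer_size]),
        ]
        deep_rnn = snt.DeepRNN(cores, skip_connections=False)
        # Time-major input: [sequence_length, batch_size, input_size].
        input_sequence = tf.constant(
            np.random.randn(sequence_length, batch_size, input_size),
            dtype=tf.float32)
        initial_state = deep_rnn.initial_state(batch_size=batch_size)
        output, unused_final_state = tf.nn.dynamic_rnn(
            deep_rnn, input_sequence,
            initial_state=initial_state,
            time_major=True)
        # The output stays time-major, with the MLP's last layer size as its depth.
        self.assertEqual(
            output.get_shape(),
            tf.TensorShape([sequence_length, batch_size, mlp_last_layer_size]))


if __name__ == "__main__":
    tf.test.main()
```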
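To see why the expected output shape is `[sequence_length, batch_size, mlp_last_layer_size]`, here is a minimal NumPy sketch of what the stacked core computes at each time step: an LSTM step, then the MLP applied to the LSTM output only (no skip connections), stacked back along the time axis. This is an illustration under stated assumptions, not Sonnet's implementation. The helper names (`deep_rnn_unroll`, `lstm_step`, `mlp_apply`, etc.) are hypothetical. The weights are random. The LSTM is a plain cell that does not reproduce Sonnet's exact parameterization (for example, it has no forget-gate bias). The MLP is assumed to use ReLU between layers and no activation after the last layer.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def make_lstm(input_size, hidden_size, rng):
    """Random weights for one plain LSTM cell (hypothetical, not Sonnet's initializers)."""
    w = rng.standard_normal((input_size + hidden_size, 4 * hidden_size)) * 0.1
    b = np.zeros(4 * hidden_size)
    return w, b


def lstm_step(params, x, state):
    """One LSTM time step on a [batch, input_size] slice; returns (output, new_state)."""
    w, b = params
    h, c = state
    gates = np.concatenate([x, h], axis=1) @ w + b
    i, f, g, o = np.split(gates, 4, axis=1)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, (h, c)


def make_mlp(input_size, output_sizes, rng):
    """One (weight, bias) pair per entry in output_sizes."""
    layers, prev = [], input_size
    for size in output_sizes:
        layers.append((rng.standard_normal((prev, size)) * 0.1, np.zeros(size)))
        prev = size
    return layers


def mlp_apply(layers, x):
    """Assumed MLP behaviour: ReLU between layers, no activation after the last one."""
    for k, (w, b) in enumerate(layers):
        x = x @ w + b
        if k < len(layers) - 1:
            x = np.maximum(x, 0.0)
    return x


def deep_rnn_unroll(inputs, hidden_size, mlp_sizes, seed=0):
    """Unroll LSTM -> MLP over time-major inputs [T, B, D] without skip connections."""
    rng = np.random.default_rng(seed)
    seq_len, batch, input_size = inputs.shape
    lstm = make_lstm(input_size, hidden_size, rng)
    mlp = make_mlp(hidden_size, mlp_sizes, rng)
    state = (np.zeros((batch, hidden_size)), np.zeros((batch, hidden_size)))
    outputs = []
    for t in range(seq_len):
        h, state = lstm_step(lstm, inputs[t], state)
        outputs.append(mlp_apply(mlp, h))  # the MLP sees only the LSTM output
    return np.stack(outputs)  # shape [T, B, mlp_sizes[-1]]


x = np.random.default_rng(1).standard_normal((3, 2, 4))
print(deep_rnn_unroll(x, hidden_size=10, mlp_sizes=[6, 7, 17]).shape)  # (3, 2, 17)
```

The LSTM's hidden size (10) never appears in the output shape. With skip connections disabled, the final core alone determines the depth of each time step's output, which is exactly what the test asserts.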