Tensorflow LSTM中的c_state和m_state是什么?

发布于 2021-01-29 17:04:21

Tensorflow r0.12的tf.nn.rnn_cell.LSTMCell文档将其描述为init:

tf.nn.rnn_cell.LSTMCell.__call__(inputs, state, scope=None)

其中state如下:

state:如果state_is_tuple为False,则必须为状态Tensor,二维,批处理x
state_size。如果state_is_tuple为True,则它必须是列大小为c_state和m_state的二维状态Tensors的元组。

什么区域c_statem_state它们如何适合LSTM?我在文档的任何地方都找不到对它们的引用。

这是文档中该页面的链接。

关注者
0
被浏览
167
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    我偶然发现了同样的问题,这就是我的理解方式!简约的LSTM示例:

    import tensorflow as tf
    
    sample_input = tf.constant([[1,2,3]],dtype=tf.float32)
    
    LSTM_CELL_SIZE = 2
    
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=True)
    state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2
    
    output, state_new = lstm_cell(sample_input, state)
    
    init_op = tf.global_variables_initializer()
    
    sess = tf.Session()
    sess.run(init_op)
    print sess.run(output)
    

    请注意,state_is_tuple=True因此在传递state给this时cell,它必须采用tuple表格形式。c_state并且m_state可能是“内存状态”和“单元状态”,尽管老实说我不确定,因为这些术语仅在文档中提及。在代码和文件中,关于LSTM-字母hc通常用于表示“输出值”和“单元状态”。
    http://colah.github.io/posts/2015-08-Understanding-
    LSTMs/

    这些张量表示单元的组合内部状态,应该一起传递。这样做的旧方法是简单地将它们连接起来,而新方法是使用元组。

    旧方法:

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=False)
    state = tf.zeros([1,LSTM_CELL_SIZE*2])
    
    output, state_new = lstm_cell(sample_input, state)
    

    新方法:

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=True)
    state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2
    
    output, state_new = lstm_cell(sample_input, state)
    

    因此,基本上我们所做的一切,都state从长度的1张量更改为长度的42张量2。内容保持不变。[0,0,0,0]成为([0,0],[0,0])。(这应该使其速度更快)



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看